OneIE代码详解,匹配论文详细解读(一)

目录

一、代码目录

二、配置文件

三、数据类型转化

四、图的定义

五、训练

1. 配置

1.1 解析命令行参数

1.2 设置GPU设备参数

2. 输出设置

3. 数据集设置

3.1 数据集设置

3.2 加载有效模板信息

3.3 计算训练、开发和测试集的批次数量

4. OneIE加载

4.1 初始化OneIE模型

4.2 加载预训练的 BERT 

4.3 设置模型是否使用GPU

5. 参数优化器

6. 模型状态

6.1 创建模型状态字典

6.2 初始化全局步数和全局特征最大步数

6.3 定义任务列表&初始化最佳验证集性能字典

7.开始训练


关于OneIE的论文解读,我在之前的一篇博客已经详细解释过了,这次我将代码和论文相结合来全面的了解OneIE具体的实施过程。

首先先下载对应的OneIE的代码,之前官网上的代码已经被下架了,也不知道为啥,我就去GitHub上找了一个看上去很完备的代码,感谢GerlinGreen,链接放在这里啦!点击这里

一、代码目录

bert文件夹中具有BERT的各种版本框架链接和BERT对应的已经预训练完的模型权重的链接

config文件夹中是关于本次OneIE项目的一些基本网络和资源调用情况的一些配置

input 文件夹中是已经标记好的xml格式的文件,一般使用brat标记

output 文件夹中是输出的一系列文件,文件格式是json的(如下),包括文档id,句子id,tokens,图中的实体,触发词,关系和角色

{
    "doc_id": "HC0003PYD",

    "sent_id": "HC0003PYD-0",

    "tokens": ["Obama", "Ignores", "North", "Korea", "in", "Address", "to", "Congress"],

    "graph": {

        "entities": [

            [0, 1, "PER", "NAM", 1.0],

            [2, 4, "GPE", "NAM", 0.4987634171811422],

            [7, 8, "ORG", "NAM", 1.0]

        ],

        "triggers": [],

        "relations": [],

        "roles": []
    }
}

output_cs文件夹中存放了将json转换为cs文件的输出结果,如何转换,只需要运行convert.py即可。

preprocessing文件夹中有以下文件

其中的文件都是对数据进行预处理:

  • process_dygiepp.py将DyGIE++格式的数据集转换为OneIE使用的格式
  • process_ace.py 将ACE2005格式的数据集转换为OneIE使用的格式
  • process_ere.py 将原始 ERE 数据集(LDC2015E29、LDC2015E68、LDC2015E78、LDC2015E107)的数据集转换为OneIE使用的格式

因为重点研究OneIE的使用,这里就不对格式转换文件的内容进行详细说明,如果之后有机会的话,我会再出一份博客。

resource文件夹主要是用于加载程序所需要的资源的一般是训练和交叉验证资源,里有两个子文件夹:splits和valid_patterns,还有几个.tsv文件

  • 其中splits文件夹里放了ACE格式的不同语言文档的词序号,应该是分隔文档用的文件
  • Valid_patterns文件夹里放了实体识别、关系识别、事件识别分别对应的词的json格式数据
  • 剩下几个tsv文件中分别放置了实体、事件、关系和角色的来源链接

二、配置文件

config.py

class Config(object):
    def __init__(self, **kwargs):
        self.coref = kwargs.pop('coref', True)
        # bert
        self.bert_model_name = kwargs.pop('bert_model_name', 'bert-large-cased')
        self.bert_cache_dir = kwargs.pop('bert_cache_dir', None)
        self.extra_bert = kwargs.pop('extra_bert', -1)
        self.use_extra_bert = kwargs.pop('use_extra_bert', False)
        # global features
        self.use_global_features = kwargs.get('use_global_features', False)
        self.global_features = kwargs.get('global_features', [])
        # model
        self.multi_piece_strategy = kwargs.pop('multi_piece_strategy', 'first')
        self.bert_dropout = kwargs.pop('bert_dropout', .5)
        self.linear_dropout = kwargs.pop('linear_dropout', .4)
        self.linear_bias = kwargs.pop('linear_bias', True)
        self.linear_activation = kwargs.pop('linear_activation', 'relu')
        self.entity_hidden_num = kwargs.pop('entity_hidden_num', 150)
        self.mention_hidden_num = kwargs.pop('mention_hidden_num', 150)
        self.event_hidden_num = kwargs.pop('event_hidden_num', 600)
        self.relation_hidden_num = kwargs.pop('relation_hidden_num', 150)
        self.role_hidden_num = kwargs.pop('role_hidden_num', 600)
        self.use_entity_type = kwargs.pop('use_entity_type', False)
        self.beam_size = kwargs.pop('beam_size', 5)
        self.beta_v = kwargs.pop('beta_v', 2)
        self.beta_e = kwargs.pop('beta_e', 2)
        self.relation_mask_self = kwargs.pop('relation_mask_self', True)
        self.relation_directional = kwargs.pop('relation_directional', False)
        self.symmetric_relations = set(kwargs.pop('symmetric_relations', ['PER-SOC']))
        # files
        self.train_file = kwargs.pop('train_file', None)
        self.dev_file = kwargs.pop('dev_file', None)
        self.test_file = kwargs.pop('test_file', None)
        self.valid_pattern_path = kwargs.pop('valid_pattern_path', None)
        self.log_path = kwargs.pop('log_path', None)
        # training
        self.accumulate_step = kwargs.pop('accumulate_step', 1)
        self.batch_size = kwargs.pop('batch_size', 10)
        self.eval_batch_size = kwargs.pop('eval_batch_size', 5)
        self.max_epoch = kwargs.pop('max_epoch', 50)
        self.learning_rate = kwargs.pop('learning_rate', 1e-3)
        self.bert_learning_rate = kwargs.pop('bert_learning_rate', 1e-5)
        self.weight_decay = kwargs.pop('weight_decay', 0.001)
        self.bert_weight_decay = kwargs.pop('bert_weight_decay', 0.00001)
        self.warmup_epoch = kwargs.pop('warmup_epoch', 5)
        self.grad_clipping = kwargs.pop('grad_clipping', 5.0)
        # others
        self.use_gpu = kwargs.pop('use_gpu', True)
        self.gpu_device = kwargs.pop('gpu_device', -1)

    @classmethod
    def from_dict(cls, dict_obj):
        """Creates a Config object from a dictionary.
        Args:
            dict_obj (Dict[str, Any]): a dict where keys are
        """
        config = cls()
        for k, v in dict_obj.items():
            setattr(config, k, v)
        return config

    @classmethod
    def from_json_file(cls, path):
        with open(path, 'r', encoding='utf-8') as r:
            return cls.from_dict(json.load(r))

    def to_dict(self):
        output = copy.deepcopy(self.__dict__)
        return output

    def save_config(self, path):
        """Save a configuration object to a file.
        :param path (str): path to the output file or its parent directory.
        """
        if os.path.isdir(path):
            path = os.path.join(path, 'config.json')
        print('Save config to {}'.format(path))
        with open(path, 'w', encoding='utf-8') as w:
            w.write(json.dumps(self.to_dict(), indent=2,
                               sort_keys=True))
    @property
    def bert_config(self):
        if self.bert_model_name.startswith('bert-'):
            return BertConfig.from_pretrained(self.bert_model_name,
                                              cache_dir=self.bert_cache_dir)
        elif self.bert_model_name.startswith('roberta-'):
            return RobertaConfig.from_pretrained(self.bert_model_name,
                                                 cache_dir=self.bert_cache_dir)
        elif self.bert_model_name.startswith('xlm-roberta-'):
            return XLMRobertaConfig.from_pretrained(self.bert_model_name,
                                                    cache_dir=self.bert_cache_dir)
        else:
            raise ValueError('Unknown model: {}'.format(self.bert_model_name))
  1. __init__: 该方法用于初始化配置对象。接受关键字参数 **kwargs,包括了模型训练所需的各种参数,如 BERT 模型名称、缓存目录、全局特征使用与否、训练文件路径等等。

__init__ 方法中,各个参数的初始化含义如下:

  • coref: 是否使用共指关系(coreference),默认为 True

  • bert_model_name: BERT 模型的名称,默认为 'bert-large-cased'

  • bert_cache_dir: BERT 模型的缓存目录,默认为 None

  • extra_bert: 额外的 BERT 模型编号,默认为 -1

  • use_extra_bert: 是否使用额外的 BERT 模型,默认为 False

  • use_global_features: 是否使用全局特征,默认为 False

  • global_features: 全局特征的列表,默认为空列表 []

  • multi_piece_strategy: 处理多词片段的策略,可选值为 'first''average',默认为 'first'

  • bert_dropout: BERT 模型的 dropout 概率,默认为 0.5

  • linear_dropout: 线性层的 dropout 概率,默认为 0.4

  • linear_bias: 是否使用线性层的偏置项,默认为 True

  • linear_activation: 线性层的激活函数,默认为 'relu'

  • entity_hidden_num: 实体隐藏层的大小,默认为 150

  • mention_hidden_num: 提及隐藏层的大小,默认为 150

  • event_hidden_num: 事件隐藏层的大小,默认为 600

  • relation_hidden_num: 关系隐藏层的大小,默认为 150

  • role_hidden_num: 角色隐藏层的大小,默认为 600

  • use_entity_type: 是否使用实体类型,默认为 False

  • beam_size: 波束搜索的大小,默认为 5

  • beta_v: 波束搜索中节点的大小,默认为 2

  • beta_e: 波束搜索中边的大小,默认为 2

  • relation_mask_self: 关系是否遮蔽自身,默认为 True

  • relation_directional: 关系是否有方向性,默认为 False

  • symmetric_relations: 对称关系的集合,默认为 set(['PER-SOC'])

  • train_file: 训练数据文件路径,默认为 None

  • dev_file: 开发数据文件路径,默认为 None

  • test_file: 测试数据文件路径,默认为 None

  • valid_pattern_path: 验证模式的路径,默认为 None

  • log_path: 日志文件路径,默认为 None

  • accumulate_step: 累积梯度的步数,默认为 1

  • batch_size: 训练时的批处理大小,默认为 10

  • eval_batch_size: 评估时的批处理大小,默认为 5

  • max_epoch: 最大训练轮数,默认为 50

  • learning_rate: 学习率,默认为 1e-3

  • bert_learning_rate: BERT 模型的学习率,默认为 1e-5

  • weight_decay: 权重衰减,默认为 0.001

  • bert_weight_decay: BERT 模型的权重衰减,默认为 0.00001

  • warmup_epoch: 学习率预热的轮数,默认为 5

  • grad_clipping: 梯度裁剪的阈值,默认为 5.0

  • use_gpu: 是否使用 GPU,默认为 True

  • gpu_device: GPU 设备的编号,默认为 -1

  1. from_dict: 该方法通过传入一个字典创建一个配置对象。

  2. from_json_file: 该方法通过读取一个 JSON 文件创建一个配置对象。

  3. to_dict: 将配置对象转化为字典。

  4. save_config: 将配置保存到文件。

  5. bert_config: 返回一个用于 BERT 模型的配置对象 (BertConfig, RobertaConfig, 或 XLMRobertaConfig)。

在这个配置类中,可以看到许多模型训练所需的参数,例如 BERT 相关的参数、模型的隐藏层大小、学习率、批处理大小等。这些参数可用于初始化模型并进行训练。

三、数据类型转化

在convert.py中我们将训练所需要使用input和output中的json数据先转化成我们需要的cs格式文件,如需进行训练需要先运行convert.py

四、图的定义

在graph.py中我们定义了一个类,用于后续我们需要的图表示,其中图的类的含义就是论文中提到的

G = (V,E)

其中,

v_i = <a_i,b_i,l_i> \in V

e_{ij} = <i,j,l_{ij}> \in E

五、训练

在这一部分我要穿插着详细讲述各个函数文件,但是一个基本的思路是按照程序运行的思路去叙述OneIE,并且我也会将各个部分对应到OneIE论文中的公式,这样可以更好的实际去感受代码运行的流程,也是解决了前后输入之间的疑问。

在train.py中我们开始了对模型的训练

1. 配置

1.1 解析命令行参数

这样让文件可以在shell面板直接通过调用python train.py -c/--config 去调用该文件。

parser = ArgumentParser()
parser.add_argument('-c', '--config', default='config/example.json')  # 设置配置文件路径的命令行参数
args = parser.parse_args()
config = Config.from_json_file(args.config)  # 从JSON文件中加载配置
# print(config.to_dict())

其默认路径是config/example.json文件夹中的配置,如下:

{
//    bert模型的配置
    "bert_model_name": "bert-large-cased",// 指定要使用的BERT模型,这种情况下是bert-large-cased
    "bert_cache_dir": "<BERT_CACHE_DIR>", //用于缓存bert模型文件的目录
//    模型的架构和特性
    "multi_piece_strategy": "average", // 处理多词标记策略
    "bert_dropout": 0.5, // bert层的丢弃率
    "use_extra_bert": true, // 使用额外的BERT层
    "extra_bert": -3, // 正如OneIE文章说的那样,倒数第三层使用了别的层

    "use_global_features": true,
    "global_features": [],
    "global_warmup": 0,

//    神经网络框架
    "linear_dropout": 0.4, // 线性层丢弃率0.4
    "linear_bias": true, // 存在线性偏差
//    网络中隐藏层设置
    "entity_hidden_num": 150,// 实体隐藏层数目
    "mention_hidden_num": 150,
    "event_hidden_num": 600,
    "relation_hidden_num": 150,
    "role_hidden_num": 600,
    "use_entity_type": true, // 是否使用实体类型信息
//    波束搜索法推断设置
    "beam_size": 20, // 推断时波束搜索大小
    "beta_v": 2, //控制波束搜索的边参数
    "beta_e": 2, //控制波束搜索的节点参数
    "relation_mask_self": true, //关系掩码设置
    "relation_directional": false, // 关系方向性设置关闭
    "symmetric_relations": ["PER-SOC"], // 使用对称的关系列表
//  数据和文件路径
    "train_file": "<TRAIN_FILE_PATH>",
    "dev_file": "<DEV_FILE_PATH>",
    "test_file": "<TEST_FILE_PATH>",
    "log_path": "<OUTPUT_DIR>",
    "valid_pattern_path": "<VALID_PATTERN_DIR>",
    "ignore_title": false,
    "ignore_first_header": false,
//  训练的一些超参数设置
    "accumulate_step": 1,
    "batch_size": 10,
    "eval_batch_size": 10,
    "max_epoch": 60,
    "learning_rate": 1e-3,
    "bert_learning_rate": 1e-5,
    "weight_decay": 1e-3,
    "bert_weight_decay": 1e-5,
    "warmup_epoch": 5,
    "grad_clipping": 5.0,
//  GPU设置
    "use_gpu": true,
    "gpu_device": 0 //gpu设备的索引,一般个人电脑或者服务器就一个所以设置为0,如果有多GPU服务器可是设置别的,或者同时使用
  }

关于配置中的信息我都加以了注释,应该很方便读者理解。这里json文件中的配置,在之后会覆盖config.py中的一些默认设置,可以根据自己的需要去更改相应的配置。

1.2 设置GPU设备参数

if use_gpu and config.gpu_device >= 0:
    torch.cuda.set_device(config.gpu_device)

2. 输出设置

这里这一部分仅仅是对输出文件的初始化,用于创建输出文件的路径,创建输出文件等。

timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())  # 生成时间戳字符串
output_dir = os.path.join(config.log_path, timestamp)  # 拼接输出目录路径
# 如果输出目录不存在,则创建
if not os.path.exists(output_dir):
    os.mkdir(output_dir)

# 拼接日志文件路径
log_file = os.path.join(output_dir, 'log.txt')
# 创建并写入配置信息到日志文件
with open(log_file, 'w', encoding='utf-8') as w:
    w.write(json.dumps(config.to_dict()) + '\n')
    # 打印日志文件路径
    print('Log file: {}'.format(log_file))

# 定义其他输出文件路径
best_role_model = os.path.join(output_dir, 'best.role.mdl')
dev_result_file = os.path.join(output_dir, 'result.dev.json')
test_result_file = os.path.join(output_dir, 'result.test.json')

3. 数据集设置

3.1 数据集设置

这一部分主要用于配置训练所需要使用的训练集、验证集和测试集的数据,其中数据路径也是由自己设置在config/example.json或者你自己定义的config.json文件中设置

# 获取BERT模型名称和配置
model_name = config.bert_model_name
# 使用Hugging Face Transformers库中的BertTokenizer加载预训练的分词器
tokenizer = BertTokenizer.from_pretrained(model_name,
                                          cache_dir=config.bert_cache_dir,
                                          do_lower_case=False)
# 创建训练集、验证集和测试集的数据对象
train_set = IEDataset(config.train_file, gpu=use_gpu,
                      relation_mask_self=config.relation_mask_self,
                      relation_directional=config.relation_directional,
                      symmetric_relations=config.symmetric_relations,
                      ignore_title=config.ignore_title)
dev_set = IEDataset(config.dev_file, gpu=use_gpu,
                    relation_mask_self=config.relation_mask_self,
                    relation_directional=config.relation_directional,
                    symmetric_relations=config.symmetric_relations)
test_set = IEDataset(config.test_file, gpu=use_gpu,
                     relation_mask_self=config.relation_mask_self,
                     relation_directional=config.relation_directional,
                     symmetric_relations=config.symmetric_relations)
# 生成词汇表平接三个数据集
vocabs = generate_vocabs([train_set, dev_set, test_set])

# 将训练集、验证集和测试集中的文本数据数值化(tokenization)
train_set.numberize(tokenizer, vocabs)
dev_set.numberize(tokenizer, vocabs)
test_set.numberize(tokenizer, vocabs)

3.2 加载有效模板信息

valid_patterns = load_valid_patterns(config.valid_pattern_path, vocabs)

这里对应着论文中的Table 1表格,当然可以根据自己的需要加载适合的模板

3.3 计算训练、开发和测试集的批次数量

batch_num = len(train_set) // config.batch_size
dev_batch_num = len(dev_set) // config.eval_batch_size + \
                (len(dev_set) % config.eval_batch_size != 0)
test_batch_num = len(test_set) // config.eval_batch_size + \
                 (len(test_set) % config.eval_batch_size != 0)

4. OneIE加载

4.1 初始化OneIE模型

# 创建 OneIE 模型实例,传入配置、词汇表和有效模式
model = OneIE(config, vocabs, valid_patterns)

这里,我们看一下model.py中如何对OneIE进行初始化的

# 模型参数初始化
    def __init__(self,
                 config,
                 vocabs,
                 valid_patterns=None):
        super().__init__()

        # vocabularies
        self.vocabs = vocabs
        self.entity_label_stoi = vocabs['entity_label']
        self.trigger_label_stoi = vocabs['trigger_label']
        self.mention_type_stoi = vocabs['mention_type']
        self.entity_type_stoi = vocabs['entity_type']
        self.event_type_stoi = vocabs['event_type']
        self.relation_type_stoi = vocabs['relation_type']
        self.role_type_stoi = vocabs['role_type']
        self.entity_label_itos = {i: s for s, i in self.entity_label_stoi.items()}
        self.trigger_label_itos = {i: s for s, i in self.trigger_label_stoi.items()}
        self.entity_type_itos = {i: s for s, i in self.entity_type_stoi.items()}
        self.event_type_itos = {i: s for s, i in self.event_type_stoi.items()}
        self.relation_type_itos = {i: s for s, i in self.relation_type_stoi.items()}
        self.role_type_itos = {i: s for s, i in self.role_type_stoi.items()}
        self.entity_label_num = len(self.entity_label_stoi)
        self.trigger_label_num = len(self.trigger_label_stoi)
        self.mention_type_num = len(self.mention_type_stoi)
        self.entity_type_num = len(self.entity_type_stoi)
        self.event_type_num = len(self.event_type_stoi)
        self.relation_type_num = len(self.relation_type_stoi)
        self.role_type_num = len(self.role_type_stoi)
        self.valid_relation_entity = set()
        self.valid_event_role = set()
        self.valid_role_entity = set()
        if valid_patterns:
            self.valid_event_role = valid_patterns['event_role']
            self.valid_relation_entity = valid_patterns['relation_entity']
            self.valid_role_entity = valid_patterns['role_entity']
        self.relation_directional = config.relation_directional
        self.symmetric_relations = config.symmetric_relations
        self.symmetric_relation_idxs = {self.relation_type_stoi[r]
                                        for r in self.symmetric_relations}

        # BERT encoder
        bert_config = config.bert_config
        bert_config.output_hidden_states = True
        self.bert_dim = bert_config.hidden_size
        self.extra_bert = config.extra_bert
        self.use_extra_bert = config.use_extra_bert
        if self.use_extra_bert:
            self.bert_dim *= 2
        self.bert = BertModel(bert_config)
        self.bert_dropout = nn.Dropout(p=config.bert_dropout)
        self.multi_piece = config.multi_piece_strategy
        # local classifiers
        self.use_entity_type = config.use_entity_type
        self.binary_dim = self.bert_dim * 2
        linear_bias = config.linear_bias
        linear_dropout = config.linear_dropout
        entity_hidden_num = config.entity_hidden_num
        mention_hidden_num = config.mention_hidden_num
        event_hidden_num = config.event_hidden_num
        relation_hidden_num = config.relation_hidden_num
        role_hidden_num = config.role_hidden_num
        role_input_dim = self.binary_dim + (self.entity_type_num if self.use_entity_type else 0)
        self.entity_label_ffn = nn.Linear(self.bert_dim, self.entity_label_num,
                                          bias=linear_bias)
        self.trigger_label_ffn = nn.Linear(self.bert_dim, self.trigger_label_num,
                                           bias=linear_bias)
        self.entity_type_ffn = Linears([self.bert_dim, entity_hidden_num,
                                        self.entity_type_num],
                                       dropout_prob=linear_dropout,
                                       bias=linear_bias,
                                       activation=config.linear_activation)
        self.mention_type_ffn = Linears([self.bert_dim, mention_hidden_num,
                                         self.mention_type_num],
                                        dropout_prob=linear_dropout,
                                        bias=linear_bias,
                                        activation=config.linear_activation)
        self.event_type_ffn = Linears([self.bert_dim, event_hidden_num,
                                       self.event_type_num],
                                      dropout_prob=linear_dropout,
                                      bias=linear_bias,
                                      activation=config.linear_activation)
        self.relation_type_ffn = Linears([self.binary_dim, relation_hidden_num,
                                          self.relation_type_num],
                                         dropout_prob=linear_dropout,
                                         bias=linear_bias,
                                         activation=config.linear_activation)
        self.role_type_ffn = Linears([role_input_dim, role_hidden_num,
                                      self.role_type_num],
                                     dropout_prob=linear_dropout,
                                     bias=linear_bias,
                                     activation=config.linear_activation)
        # global features
        self.use_global_features = config.use_global_features
        self.global_features = config.global_features
        self.global_feature_maps = generate_global_feature_maps(vocabs, valid_patterns)
        self.global_feature_num = sum(len(m) for k, m in self.global_feature_maps.items()
                                      if k in self.global_features or
                                      not self.global_features)
        self.global_feature_weights = nn.Parameter(
            torch.zeros(self.global_feature_num).fill_(-0.0001))
        # decoder
        self.beam_size = config.beam_size
        self.beta_v = config.beta_v
        self.beta_e = config.beta_e
        # loss functions
        self.entity_criteria = torch.nn.CrossEntropyLoss()
        self.event_criteria = torch.nn.CrossEntropyLoss()
        self.mention_criteria = torch.nn.CrossEntropyLoss()
        self.relation_criteria = torch.nn.CrossEntropyLoss()
        self.role_criteria = torch.nn.CrossEntropyLoss()
        # others
        self.entity_crf = CRF(self.entity_label_stoi, bioes=False)
        self.trigger_crf = CRF(self.trigger_label_stoi, bioes=False)
        self.pad_vector = nn.Parameter(torch.randn(1, 1, self.bert_dim))

首先,它初始化了模型的词汇表(vocabs)和各个标签的映射关系,同时定义了标签的数量等信息。还设置了一些关于关系方向性和对称性的配置。

接着,初始化了BERT编码器,包括BERT的维度、是否使用额外的BERT、BERT的输出维度等。还初始化了BERT的线性层和dropout层。

之后,定义了局部分类器的结构,包括实体标签、触发标签、实体类型、提及类型、事件类型、关系类型和角色类型的线性层。

然后,根据配置初始化了全局特征相关的组件,包括是否使用全局特征、全局特征的类型、全局特征的映射等。

接下来,定义了解码器的一些参数,包括波束搜索的大小\theta\beta_v\beta_e

然后,初始化了损失函数,包括实体、事件、提及、关系和角色的交叉熵损失函数。

最后,初始化了CRF(条件随机场)层,用于对实体标签和触发标签进行建模。同时,定义了一个用于填充的参数矩阵。

看上去很复杂其实就简单理解为将所有网络中的参数随机或者按照一定的规则进行初始化就好。

4.2 加载预训练的 BERT 

# 加载预训练的 BERT 模型参数到 OneIE 模型中
model.load_bert(model_name, cache_dir=config.bert_cache_dir)

同样我们看一下model.py中对于预训练的BERT如何加载的

    def load_bert(self, name, cache_dir=None):
        """Load the pre-trained BERT model (used in training phrase)
        :param name (str): pre-trained BERT model name
        :param cache_dir (str): path to the BERT cache directory
        """
        print('Loading pre-trained BERT model {}'.format(name))
        self.bert = BertModel.from_pretrained(name,
                                              cache_dir=cache_dir,
                                              output_hidden_states=True)

这里其实是从Transformer包中直接加载了之前Google为我们训练好的BERT模型。

4.3 设置模型是否使用GPU

if use_gpu:
    model.cuda(device=config.gpu_device)

5. 参数优化器

param_groups = [
    {
        'params': [p for n, p in model.named_parameters() if n.startswith('bert')],
        'lr': config.bert_learning_rate, 'weight_decay': config.bert_weight_decay
    },
    {
        'params': [p for n, p in model.named_parameters() if not n.startswith('bert')
                   and 'crf' not in n and 'global_feature' not in n],
        'lr': config.learning_rate, 'weight_decay': config.weight_decay
    },
    {
        'params': [p for n, p in model.named_parameters() if not n.startswith('bert')
                   and ('crf' in n or 'global_feature' in n)],
        'lr': config.learning_rate, 'weight_decay': 0
    }
]
# 使用 AdamW 优化器,传入参数组
optimizer = AdamW(params=param_groups)
# 使用学习率预热和线性调度器,设置预热步数和总步数
schedule = get_linear_schedule_with_warmup(optimizer,
                                           num_warmup_steps=batch_num * config.warmup_epoch,
                                           num_training_steps=batch_num * config.max_epoch)

AdamW 优化器是对 Adam 优化器的一个变种,主要用于解决 Adam 在权重衰减(weight decay)方面的一些问题。在标准的 Adam 优化器中,权重衰减项会直接添加到每个权重的梯度上,这可能导致不同层的权重在更新过程中被调整得不均匀。

6. 模型状态

6.1 创建模型状态字典

# 创建模型状态字典,包含模型参数、配置、词汇表和有效模式
state = dict(model=model.state_dict(),
             config=config.to_dict(),
             vocabs=vocabs,
             valid=valid_patterns)

6.2 初始化全局步数和全局特征最大步数

global_step = 0
global_feature_max_step = int(config.global_warmup * batch_num) + 1
print('global feature max step:', global_feature_max_step)

6.3 定义任务列表&初始化最佳验证集性能字典

# 定义任务列表
tasks = ['entity', 'trigger', 'relation', 'role']
# 初始化最佳验证集性能字典
best_dev = {k: 0 for k in tasks}

到这里都是为了后续存储使用模型输出参数而做准备

7.开始训练

for epoch in range(config.max_epoch):
    print('Epoch: {}'.format(epoch))

从这里开始进入循环开始训练,为了详细讲述代码,后续代码将不嵌套循环讲解。

由于篇幅限制,后面的代码讲解可以关注我,进入我的主页查看《OneIE代码详解,匹配论文详细解读(一)》

  • 5
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 1
    评论
对于surf特征匹配代码详解,以下是一个示例: ```python import cv2 import numpy as np # 读取两张图片 img1 = cv2.imread('image1.jpg', cv2.IMREAD_GRAYSCALE) img2 = cv2.imread('image2.jpg', cv2.IMREAD_GRAYSCALE) # 创建SURF对象 surf = cv2.xfeatures2d.SURF_create() # 检测关键点和描述符 kp1, des1 = surf.detectAndCompute(img1, None) kp2, des2 = surf.detectAndCompute(img2, None) # 创建BFMatcher对象 bf = cv2.BFMatcher() # 使用KNN匹配算法,返回k个最佳匹配 matches = bf.knnMatch(des1, des2, k=2) # 应用比例测试,保留好的匹配 good_matches = [] for m, n in matches: if m.distance < 0.75 * n.distance: good_matches.append(m) # 绘制匹配结果 result = cv2.drawMatches(img1, kp1, img2, kp2, good_matches, None, flags=2) # 显示结果 cv2.imshow("SURF Matching", result) cv2.waitKey(0) cv2.destroyAllWindows() ``` 代码详解: 1. 导入必要的库:`cv2`用于图像处理,`numpy`用于矩阵操作。 2. 读取两张待匹配的灰度图像。 3. 创建SURF对象,通过`cv2.xfeatures2d.SURF_create()`创建。 4. 使用SURF对象分别检测关键点和计算描述符,通过`detectAndCompute()`方法实现。 5. 创建BFMatcher对象,用于进行特征匹配。 6. 使用KNN匹配算法,通过`bf.knnMatch()`方法进行特征匹配,返回k个最佳匹配。 7. 应用比例测试,保留好的匹配,使用0.75的阈值。 8. 绘制匹配结果,通过`cv2.drawMatches()`方法实现。 9. 显示结果,通过`cv2.imshow()`方法显示图像,`cv2.waitKey()`等待按键响应,`cv2.destroyAllWindows()`关闭窗口。 这段代码实现了SURF特征提取和匹配的过程,可以用于在两张图像中寻找相似的特征点并进行匹配

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

星宇星静

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值