[keras-bert Study Notes] 2. Saving and loading a pre-trained model, and adding layers on top of it for supervised training (fine-tuning)

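This post shows how to pre-train a model with Keras-BERT and save it, and then how to load the pre-trained model and add extra layers on top of it for supervised training (fine-tuning).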

1. Pre-train the model and save it

import os
import tensorflow as tf
from keras_bert import (get_model, compile_model, get_base_dict, gen_batch_inputs)
from indoor_location.utils import get_sentence_pairs


seqence_len = 26  # number of valid APs (access points)
pretrain_datafile_name = "..\\data\\sampleset_data\\trainset_day20-1-8_points20_average_interval_500ms.csv"
MODEL_DIR = "..\\model\\"
pretrained_model_path = MODEL_DIR + "pretrained_bert1.h5"


def bert_indoorlocation_pretrain():

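    # Prepare the training and validation data
    # (both sets are generated from the same sentence pairs; only the random masking differs)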
    sentence_pairs = get_sentence_pairs(pretrain_datafile_name)
    token_dict = get_base_dict()
    for pairs in sentence_pairs:
        for token in pairs[0] + pairs[1]:
            if token not in token_dict:
                token_dict[token] = len(token_dict)
    token_list = list(token_dict.keys())

    x_train, y_train = gen_batch_inputs(
        sentence_pairs,
        token_dict,
        token_list,
        seq_len=seqence_len,
        mask_rate=0.3,
        swap_sentence_rate=1.0,
    )
    x_test, y_test = gen_batch_inputs(
        sentence_pairs,
        token_dict,
        token_list,
        seq_len=seqence_len,
        mask_rate=0.3,
        swap_sentence_rate=1.0,
    )

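    # TF 1.x session config: fall back to another device when an op cannot be placed,
    # and let GPU memory grow on demand instead of reserving it all up front.
    # (The original also built a tf.GPUOptions object that was never passed to the config, so it is dropped.)
    config = tf.ConfigProto(allow_soft_placement=True)
    config.gpu_options.allow_growth = True

    # Create the session
    with tf.Session(config=config) as session: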

        # Build the model
        model = get_model(
            token_num=len(token_dict),
            head_num=2,
            transformer_num=2,
            embed_dim=12,
            feed_forward_dim=100,
            seq_len=seqence_len,
            pos_num=seqence_len,
            dropout_rate=0.05,
            attention_activation='gelu',
        )

        # Compile the model
        print("compiling model .....")
        compile_model(
            model,
            learning_rate=1e-3,
            decay_steps=30000,
            warmup_steps=10000,
            weight_decay=1e-3,
        )
        model.summary()
       
        # Train the model
        print("training network...")
        H = model.fit(x_train, y_train, validation_data=(x_test, y_test),
                      batch_size=32, epochs=1, verbose=2)
        # Save the model
        model.save(pretrained_model_path)  # writes both the model architecture and the weights to the .h5 file


bert_indoorlocation_pretrain()
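
The vocabulary-building loop at the top of bert_indoorlocation_pretrain() can be tried on its own without keras_bert. The sketch below is only an illustration: base_dict() is a stand-in for get_base_dict() and assumes it returns the five special tokens with ids 0 to 4 (check your installed version), and the sentence pairs are made-up AP names rather than real data from get_sentence_pairs().

```python
def base_dict():
    # Stand-in for keras_bert.get_base_dict(); the exact contents are an assumption.
    return {"": 0, "[UNK]": 1, "[CLS]": 2, "[SEP]": 3, "[MASK]": 4}


def build_token_dict(sentence_pairs):
    """Assign the next free id to every token not seen before."""
    token_dict = base_dict()
    for first, second in sentence_pairs:
        for token in first + second:
            if token not in token_dict:
                token_dict[token] = len(token_dict)
    return token_dict


# Hypothetical sentence pairs; the real ones come from get_sentence_pairs().
pairs = [
    (["ap1", "ap2"], ["ap3", "ap1"]),
    (["ap2", "ap4"], ["ap5"]),
]

token_dict = build_token_dict(pairs)
token_list = list(token_dict.keys())  # position in the list == token id

print(token_dict)
print(token_list)
```

Because ids are handed out in first-seen order, token_list[i] is always the token whose id is i, which is why the same token_dict and token_list can be passed together to gen_batch_inputs.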

Output:

Using TensorFlow backend.
2020-04-08 13:27:41.211842: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1432] Found device 0 with properties: 
name: GeForce RTX 2080 Ti major: 7 minor: 5 memoryClockRate(GHz): 1.545
pciBusID: 0000:01:00.0
totalMemory: 11.00GiB freeMemory: 9.03GiB
2020-04-08 13:27:41.212025: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1511] Adding visible gpu devices: 0
compiling model .....
2020-04-08 13:27:42.030384: I tensorflow/core/common_runtime/gpu/gpu_device.cc:982] Device interconnect StreamExecutor with strength 1 edge matrix:
2020-04-08 13:27:42.030481: I tensorflow/core/common_runtime/gpu/gpu_device.cc:988]      0 
2020-04-08 13:27:42.030536: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1001] 0:   N 
2020-04-08 13:27:42.030685: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1115] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 8712 MB memory) -> physical GPU (device: 0, name: GeForce RTX 2080 Ti, pci bus id: 0000:01:00.0, compute capability: 7.5)
Model: "model_1"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
Input-Token (InputLayer)        (None, 26)           0                                            
__________________________________________________________________________________________________
Input-Segment (InputLayer)      (None, 26)           0                                            
__________________________________________________________________________________________________
Embedding-Token (TokenEmbedding [(None, 26, 12), (57 684         Input-Token[0][0]                
__________________________________________________________________________________________________
Embedding-Segment (Embedding)   (None, 26, 12)       24          Input-Segment[0][0]              
__________________________________________________________________________________________________
Embedding-Token-Segment (Add)   (None, 26, 12)       0           Embedding-Token[0][0]            
                                                                 Embedding-Segment[0][0]          
__________________________________________________________________________________________________
Embedding-Position (PositionEmb (None, 26, 12)       312         Embedding-Token-Segment[0][0]    
__________________________________________________________________________________________________
Embedding-Dropout (Dropout)     (None, 26, 12)       0           Embedding-Position[0][0]         
__________________________________________________________________________________________________
Embedding-Norm (LayerNormalizat (None, 26, 12)       24          Embedding-Dropout[0][0]          
__________________________________________________________________________________________________
Encoder-1-MultiHeadSelfAttentio (None, 26, 12)       624         Embedding-Norm[0][0]             
__________________________________________________________________________________________________
Encoder-1-MultiHeadSelfAttentio (None, 26, 12)       0           Encoder-1-MultiHeadSelfAttention[
__________________________________________________________________________________________________
Encoder-1-MultiHeadSelfAttentio (None, 26, 12)       0           Embedding-Norm[0][0]             
                                                                 Encoder-1-MultiHeadSelfAttention-
__________________________________________________________________________________________________
Encoder-1-MultiHeadSelfAttentio (None, 26, 12)       24          Encoder-1-MultiHeadSelfAttention-
__________________________________________________________________________________________________
Encoder-1-FeedForward (FeedForw (None, 26, 12)       2512        Encoder-1-MultiHeadSelfAttention-
__________________________________________________________________________________________________
Encoder-1-FeedForward-Dropout ( (None, 26, 12)       0           Encoder-1-FeedForward[0][0]      
__________________________________________________________________________________________________
Encoder-1-FeedForward-Add (Add) (None, 26, 12)       0           Encoder-1-MultiHeadSelfAttention-
                                                                 Encoder-1-FeedForward-Dropout[0][
__________________________________________________________________________________________________
Encoder-1-FeedForward-Norm (Lay (None, 26, 12)       24          Encoder-1-FeedForward-Add[0][0]  
__________________________________________________________________________________________________
Encoder-2-MultiHeadSelfAttentio (None, 26, 12)       624         Encoder-1-FeedForward-Norm[0][0] 
__________________________________________________________________________________________________
Encoder-2-MultiHeadSelfAttentio (None, 26, 12)       0           Encoder-2-MultiHeadSelfAttention[
__________________________________________________________________________________________________
Encoder-2-MultiHeadSelfAttentio (None, 26, 12)       0           Encoder-1-FeedForward-Norm[0][0] 
                                                                 Encoder-2-MultiHeadSelfAttention-
__________________________________________________________________________________________________
Encoder-2-MultiHeadSelfAttentio (None, 26, 12)       24          Encoder-2-MultiHeadSelfAttention-
__________________________________________________________________________________________________
Encoder-2-FeedForward (FeedForw (None, 26, 12)       2512        Encoder-2-MultiHeadSelfAttention-
__________________________________________________________________________________________________
Encoder-2-FeedForward-Dropout ( (None, 26, 12)       0           Encoder-2-FeedForward[0][0]      
__________________________________________________________________________________________________
Encoder-2-FeedForward-Add (Add) (None, 26, 12)       0           Encoder-2-
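
The Param # column in the summary above can be checked by hand from the arguments passed to get_model. The following is a small arithmetic sketch, not part of the original script; the vocabulary size of 57 is read off the truncated "(57" in the Embedding-Token output shape and from 684 / 12, and the per-layer formulas assume standard dense projections with biases.

```python
token_num = 57    # vocabulary size, inferred from the Embedding-Token row
embed_dim = 12    # embed_dim passed to get_model
seq_len = 26      # pos_num / seq_len passed to get_model
ff_dim = 100      # feed_forward_dim passed to get_model

params = {
    "Embedding-Token": token_num * embed_dim,
    "Embedding-Segment": 2 * embed_dim,           # two segment ids
    "Embedding-Position": seq_len * embed_dim,
    "LayerNormalization": 2 * embed_dim,          # gamma and beta
    # Q, K, V and output projections, each embed_dim x embed_dim plus a bias
    "MultiHeadSelfAttention": 4 * (embed_dim * embed_dim + embed_dim),
    # two dense layers: embed_dim -> ff_dim -> embed_dim, each with a bias
    "FeedForward": (embed_dim * ff_dim + ff_dim) + (ff_dim * embed_dim + embed_dim),
}

for name, count in params.items():
    print(f"{name:24s} {count}")
```

These values match the 684, 24, 312, 24, 624 and 2512 shown in the summary. Note that head_num does not appear in any formula: splitting the attention into heads changes how the projections are used, not how many weights they hold.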