1. argparse in four steps
import argparse                               # (1) import the module

parser = argparse.ArgumentParser()            # (2) create the parser
parser.add_argument('--dataset', type=str)    # (3) register the arguments
args = parser.parse_args()                    # (4) parse the command line
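Later steps read more fields off args (pos_th, dep_th, chunk_th, bert_model, pool_method, optim_type, zen_model, and the dropout/fusion options in step 5), so they must be registered before parse_args() runs. A sketch of the extra declarations; the defaults are illustrative assumptions, not values from the original script:

parser.add_argument('--pos_th', type=int, default=5)             # POS feature frequency threshold
parser.add_argument('--dep_th', type=int, default=5)             # dependency feature threshold
parser.add_argument('--chunk_th', type=int, default=5)           # chunk feature threshold
parser.add_argument('--bert_model', type=str, default='cn-wwm')  # BERT weights for BertEmbedding
parser.add_argument('--pool_method', type=str, default='first')  # sub-token pooling: first/last/avg/max
parser.add_argument('--optim_type', type=str, default='adam')    # 'sgd' or anything else for Adam
parser.add_argument('--zen_model', type=str, default='')         # empty string disables ZEN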
2. Set the dataset-specific hyperparameters
dataset = args.dataset
if dataset == 'WB':            # hyperparameters tuned for the WB dataset
    n_heads = 2                # attention heads
    head_dims = 128            # dimension per head
    num_layers = 2             # encoder layers
    lr = 0.0007
    attn_type = 'adatrans'     # TENER's adaptive transformer attention
    n_epochs = 50
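Step 5 below uses d_model and dim_feedforward without defining them; in the TENER training scripts they are conventionally derived from the per-head settings. A sketch following that convention:

d_model = n_heads * head_dims          # 2 heads * 128 dims = 256
dim_feedforward = int(2 * d_model)     # feedforward width, twice the model width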
3. Cache the processed data
from fastNLP import cache_results

# model_type, encoding_type, and normalize_embed are set earlier in the original script
name = 'caches/{}_{}_{}_{}_{}_{}_{}.pkl'.format(dataset, model_type, encoding_type, normalize_embed,
                                                args.pos_th, args.dep_th, args.chunk_th)

@cache_results(name, _refresh=False)   # decorates load_data() in step 4 and caches its return value
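cache_results pickles the decorated function's return value to the given path on the first call and simply unpickles it on later calls; passing _refresh=True (or deleting the .pkl file) forces recomputation. A minimal self-contained sketch:

from fastNLP import cache_results

@cache_results('caches/demo.pkl', _refresh=False)
def expensive_build():
    # the body runs only when the cache file is missing (or _refresh=True)
    return {'n_train': 10000}

data = expensive_build()   # first call computes and saves; later calls just load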
4. Load the data
from fastNLP.embeddings import StaticEmbedding, BertEmbedding, StackEmbedding
from modules.pipe import CNNERPipe
def load_data():
    paths = {
        'train': 'data/{}/train.txt'.format(dataset),
        'dev': 'data/{}/dev.txt'.format(dataset),
        'test': 'data/{}/test.txt'.format(dataset),
    }
    # build char/bigram/trigram fields and the target vocab from the raw files
    data_bundle = CNNERPipe(bigrams=True, trigrams=True,
                            encoding_type=encoding_type).process_from_file(paths)
    # static character embeddings pretrained on Gigaword
    embed = StaticEmbedding(data_bundle.get_vocab('chars'),
                            model_dir_or_name='data/gigaword_chn.all.a2b.uni.ite50.vec',
                            min_freq=1, only_norm_found_vector=normalize_embed,
                            word_dropout=0.01, dropout=0.3)
    # frozen BERT features over the same character vocabulary
    bert_embed = BertEmbedding(vocab=data_bundle.get_vocab('chars'), model_dir_or_name=args.bert_model,
                               layers='-1', pool_method=args.pool_method, word_dropout=0, dropout=0.5,
                               include_cls_sep=False, pooled_cls=True, requires_grad=False,
                               auto_truncate=False)
    # tencent_embed (a second StaticEmbedding) is built elsewhere in the original script
    embed = StackEmbedding([embed, tencent_embed, bert_embed], dropout=0, word_dropout=0.02)
    return data_bundle, embed   # the original script returns the feature dictionaries as well
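Because load_data() carries the @cache_results decorator from step 3, the first call processes the raw files and writes the cache; later calls with the same threshold arguments load instantly. Assuming the two return values sketched above:

data_bundle, embed = load_data()   # cached after the first run
print(data_bundle)                 # shows the datasets and vocabulary sizes in the bundle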
5. Build the model
from models.TENER import TENER   # import path as in the TENER repo

# bi_embed/tri_embed*, pos_embed, the feature dictionaries, and the remaining scalars
# (after_norm, fc_dropout, vocab_size, ...) are prepared earlier in the original script
model = TENER(tag_vocab=data_bundle.get_vocab('target'), embed=embed, num_layers=num_layers,
              d_model=d_model, n_head=n_heads,
              feedforward_dim=dim_feedforward, dropout=args.trans_dropout,
              after_norm=after_norm, attn_type=attn_type,
              bi_embed=bi_embed,                 # bigram embeddings
              bi_embed2=bi_embed2,
              tri_embed=tri_embed,               # trigram embeddings
              tri_embed2=tri_embed2,
              tri_embed3=tri_embed3,
              fc_dropout=fc_dropout,
              pos_embed=pos_embed,
              scale=attn_type == 'transformer',  # scale attention only for the vanilla transformer
              use_knowledge=True,                # enable the key-value memory over features
              feature2count=feature2count,
              vocab_size=vocab_size,
              feature_vocab_size=feature_vocab_size,
              kv_attn_type=args.kv_attn_type,
              memory_dropout=args.memory_dropout,
              fusion_dropout=args.fusion_dropout,
              fusion_type=args.fusion_type,
              highway_layer=args.highway_layer,
              key_embed_dropout=args.key_embed_dropout,
              knowledge_type=args.knowledge_type,
              use_zen=args.zen_model != "")      # an optional ZEN encoder
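Note scale=attn_type=='transformer': the TENER paper argues that dropping the 1/sqrt(d_k) scaling of vanilla dot-product attention gives sharper attention distributions that suit NER, so the model scales only when falling back to a standard transformer encoder.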
6. Set up the optimizer
from torch import optim

if args.optim_type == 'sgd':
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
else:
    optimizer = optim.Adam(model.parameters(), lr=lr, betas=(0.9, 0.99))
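Both branches reuse the lr chosen in step 2. Note the Adam betas: (0.9, 0.99) decays the second-moment estimate faster than PyTorch's default of (0.9, 0.999).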
7. Using callbacks
from fastNLP import Trainer, GradientClipCallback, WarmupCallback

callbacks = []
clip_callback = GradientClipCallback(clip_type='value', clip_value=5)   # clamp each gradient to [-5, 5]
# EvaluateCallback here is the repo's extended version: stock fastNLP's EvaluateCallback
# takes only the evaluation data, not the knowledge/feature arguments below.
evaluate_callback = EvaluateCallback(data=data_bundle.get_dataset('test'),
                                     use_knowledge=True,
                                     knowledge_type=args.knowledge_type,
                                     pos_th=args.pos_th,
                                     dep_th=args.dep_th,
                                     chunk_th=args.chunk_th,
                                     test_feature_data=test_feature_data,
                                     feature2count=feature2count,
                                     feature2id=feature2id,
                                     id2feature=id2feature,
                                     # use_zen=args.zen_model != "",
                                     # zen_model=zen_model,
                                     # zen_dataset=zen_test_dataset
                                     )
if warmup_steps > 0:
    warmup_callback = WarmupCallback(warmup_steps, schedule='linear')   # linear warmup, then linear decay
    callbacks.append(warmup_callback)
callbacks.extend([clip_callback, evaluate_callback])
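GradientClipCallback with clip_type='value' clamps every gradient element into [-5, 5] before each optimizer step, while WarmupCallback(schedule='linear') ramps the learning rate from 0 up to lr over the first warmup_steps updates and then decays it linearly. The clipping is equivalent to this raw PyTorch call (shown for reference, not repo code):

import torch
torch.nn.utils.clip_grad_value_(model.parameters(), clip_value=5)   # what the callback applies each step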
8. Trainer
from fastNLP import SpanFPreRecMetric, BucketSampler
# The knowledge/ZEN keyword arguments below are accepted by this repo's modified
# Trainer, not by stock fastNLP's Trainer.
trainer = Trainer(data_bundle.get_dataset('train'), model, optimizer, batch_size=batch_size,
                  sampler=BucketSampler(),          # batch sequences of similar length together
                  num_workers=0, n_epochs=n_epochs, dev_data=data_bundle.get_dataset('dev'),
                  metrics=SpanFPreRecMetric(tag_vocab=data_bundle.get_vocab('target'),
                                            encoding_type=encoding_type),   # span-level P/R/F1
                  dev_batch_size=batch_size, callbacks=callbacks, device=device, test_use_tqdm=False,
                  use_tqdm=True, print_every=300, save_path=save_path,
                  use_knowledge=True,
                  knowledge_type=args.knowledge_type,
                  pos_th=args.pos_th,
                  dep_th=args.dep_th,
                  chunk_th=args.chunk_th,
                  train_feature_data=train_feature_data,
                  test_feature_data=dev_feature_data,   # dev-split features, used at validation time
                  feature2count=feature2count,
                  feature2id=feature2id,
                  id2feature=id2feature,
                  logger_func=write_log,
                  use_zen=args.zen_model != "",
                  zen_model=zen_model,
                  zen_train_dataset=zen_train_dataset,
                  zen_dev_dataset=zen_dev_dataset)
trainer.train(load_best_model=False)
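With load_best_model=False, the weights left in memory after training are simply those of the final epoch; the best-on-dev checkpoint is still written under save_path, and test-set scores come from the extended EvaluateCallback, which fires after every validation pass.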