第一阶段
第二阶段
nn.utils.clip_grad_norm_()官方文档介绍
第三阶段
准备工作
并行实现
多线程
1)相关函数
2)踩坑记录
RuntimeError: Cannot re-initialize CUDA in forked subprocess
RuntimeError: context has already been set(multiprocessing)
3)相关代码示例
法一:使用map_async()函数
# Per-process worker function: handles exactly one sentence.
def single_sentence(device, model, enc_vocab2id, dec_id2vocab, dec_vocab2id, sentence):
    """Translate a single test sentence with `model` and score it.

    Args:
        device: torch device the model runs on.
        model: the trained seq2seq model (already moved to `device`).
        enc_vocab2id / dec_id2vocab / dec_vocab2id: vocabulary lookup tables.
        sentence: one raw line from the test file.

    Returns:
        Tuple of (bleu_score_1, bleu_score_2, bleu_score_3, bleu_score_4)
        for this sentence.
    """
    line = sentence
    ...  # model inference and BLEU computation elided in the original post
    return bleu_score_1, bleu_score_2, bleu_score_3, bleu_score_4
# Driver: fan the test sentences out over a process pool with map_async().
if __name__ == '__main__':
    ...
    with torch.no_grad():
        # Context manager guarantees the test file is closed after reading.
        with open(test_file_path, 'r', encoding='utf-8') as f:
            test_lines = f.readlines()
        test_size = len(test_lines)
        # Explicit accumulator initialization (elided in the original `...`).
        all_bleu_score_1 = all_bleu_score_2 = all_bleu_score_3 = all_bleu_score_4 = 0.0
        pool = Pool(2)  # number of parallel worker processes
        # Bind the invariant arguments once; each worker receives only a sentence.
        par_single_sentence = partial(single_sentence, device, model, enc_vocab2id, dec_id2vocab, dec_vocab2id)
        results = pool.map_async(par_single_sentence, test_lines)
        # .get() blocks until every sentence has been processed; results are
        # returned in input order.
        for res in results.get():
            all_bleu_score_1 += res[0]
            all_bleu_score_2 += res[1]
            all_bleu_score_3 += res[2]
            all_bleu_score_4 += res[3]
        pool.close()
        pool.join()
        # Average BLEU scores over the whole test set.
        avg_bleu_score_1 = all_bleu_score_1 / test_size
        avg_bleu_score_2 = all_bleu_score_2 / test_size
        avg_bleu_score_3 = all_bleu_score_3 / test_size
        avg_bleu_score_4 = all_bleu_score_4 / test_size
        ...
结果:(非并行左,并行右,并行数 = 2)
法二:使用apply_async()函数
# Per-process worker function: handles exactly one sentence.
def single_sentence(device, model, enc_vocab2id, dec_id2vocab, dec_vocab2id, sentence):
    """Translate a single test sentence with `model` and score it.

    Args:
        device: torch device the model runs on.
        model: the trained seq2seq model (already moved to `device`).
        enc_vocab2id / dec_id2vocab / dec_vocab2id: vocabulary lookup tables.
        sentence: one raw line from the test file.

    Returns:
        Tuple of (bleu_score_1, bleu_score_2, bleu_score_3, bleu_score_4)
        for this sentence.
    """
    line = sentence
    ...  # model inference and BLEU computation elided in the original post
    return bleu_score_1, bleu_score_2, bleu_score_3, bleu_score_4
# Driver: submit one apply_async() task per sentence.
if __name__ == '__main__':
    ...
    with torch.no_grad():
        # Context manager guarantees the test file is closed after reading.
        with open(test_file_path, 'r', encoding='utf-8') as f:
            test_lines = f.readlines()
        test_size = len(test_lines)
        # Explicit accumulator initialization (elided in the original `...`).
        all_bleu_score_1 = all_bleu_score_2 = all_bleu_score_3 = all_bleu_score_4 = 0.0
        pool = Pool(2)  # number of parallel worker processes
        # Bind the invariant arguments ONCE, outside the loop (the original
        # rebuilt the partial on every iteration).
        par_single_sentence = partial(single_sentence, device, model, enc_vocab2id, dec_id2vocab, dec_vocab2id)
        # Submit every sentence before collecting any result, otherwise the
        # per-item .get() serializes the pool and no work overlaps.
        # NOTE: apply_async takes an args TUPLE for one call — the original
        # passed the whole test_lines list, and then indexed the returned
        # AsyncResult directly (res[0]), which raises TypeError; the value
        # must be fetched with .get().
        async_results = [pool.apply_async(par_single_sentence, (line,)) for line in test_lines]
        pool.close()
        for ar in async_results:
            res = ar.get()  # blocks until this sentence's scores are ready
            all_bleu_score_1 += res[0]
            all_bleu_score_2 += res[1]
            all_bleu_score_3 += res[2]
            all_bleu_score_4 += res[3]
        pool.join()
        # Average BLEU scores over the whole test set.
        avg_bleu_score_1 = all_bleu_score_1 / test_size
        avg_bleu_score_2 = all_bleu_score_2 / test_size
        avg_bleu_score_3 = all_bleu_score_3 / test_size
        avg_bleu_score_4 = all_bleu_score_4 / test_size
        ...