学习过程:
近期通过查阅论文以及上网翻阅资料,查阅文献,我对完成课题所使用的神经网络模型推导和pytorch原理有了大致的了解。
主要关注了以下概念
深度学习之BP算法
梯度下降与随机梯度下降概念及推导过程:
Pytorch实现RNN:
RNN教程之-2 LSTM实战:
主要参考的文献:深度学习之自然语言处理进阶
项目进度:
目前我的主要工作为完成生成的问题集合的问题第一次筛选,主要实现思路为将参考答案和带标记的正确答案使用word2vec的分布式表示为空间中的n维向量,并以问题q为划分对象,从而达到筛选出优秀的q的作用。
表示过程:使用cbow模型
图示过程:
关键代码:
# 读入数据
corpus, word_to_id, id_to_word = ptb.load_data('train')
vocab_size = len(word_to_id)
contexts, target = create_contexts_target(corpus, window_size)
if config.GPU:
contexts, target = to_gpu(contexts), to_gpu(target)
# 生成模型等
model = CBOW(vocab_size, hidden_size, window_size, corpus)
# model = SkipGram(vocab_size, hidden_size, window_size, corpus)
optimizer = Adam()
trainer = Trainer(model, optimizer)
# 开始学习
trainer.fit(contexts, target, max_epoch, batch_size)
trainer.plot()
训练结果(部分):
day [-0.10517368 -0.11767567 -0.08770754 -0.08608673 0.1416603 0.12295089
0.12176713 -0.15600257 -0.13488823 0.10433038 -0.11382852 0.11108447
0.10454829 0.14350444 -0.0919579 -0.11685202 -0.11281923 0.10545972
-0.12588383 -0.13928771 -0.12421538 -0.12743862 0.14197709 0.13545138
0.09636388 0.14729802 0.12356224 -0.1310601 -0.11960482 -0.0868628
-0.1253302 -0.11179367 -0.10483316 0.13274033 0.12030896 -0.14628492
-0.13341272 0.11815492 0.16607009 0.1374148 0.11466035 0.11782282
-0.09025079 0.12452562 0.08582853 0.11918545 0.08405499 -0.07441092
0.12264467 0.12176208 0.10547844 0.11072478 -0.14566886 0.12585992
0.13482785 0.11935122 0.08624908 0.10331067 -0.12355839 0.10343292
0.12569971 -0.10155733 -0.11112247 -0.10061409 0.12013828 0.10144176
0.12023008 -0.13868822 0.11454513 0.09895245 0.11617471 -0.11939891
0.11349303 -0.13481392 -0.10593541 0.09760397 0.12053981 -0.10121962
-0.11299989 -0.11750894 -0.11038833 0.11607917 -0.11827555 0.12450801
-0.07779734 -0.11477651 0.12159468 0.10950866 0.12041226 0.11615779
0.11896573 0.11957716 0.12469604 0.11058014 -0.10453165 -0.1114757
0.1147315 0.1269861 0.1373257 0.12466712]
days [-0.18234362 -0.17784251 -0.14497586 -0.14459954 0.18304142 0.16626157
0.13798636 -0.18846454 -0.19211729 0.15492904 -0.15686871 0.15788968
0.17225228 0.18458408 -0.16689463 -0.12362264 -0.15635088 0.1672898
-0.16362582 -0.17262168 -0.15584452 -0.18600425 0.15711908 0.17349496
0.16788934 0.16680212 0.2016448 -0.17601967 -0.14748904 -0.16601637
-0.1636122 -0.18051462 -0.12198152 0.1764061 0.18702173 -0.17925723
-0.18451986 0.16599563 0.17094922 0.19764766 0.1548859 0.17285578
-0.1474535 0.18246648 0.13471271 0.18385014 0.14013349 -0.12104414
0.17972319 0.19462992 0.1564441 0.16953212 -0.15377274 0.18620114
0.16093694 0.1384015 0.11233116 0.17702718 -0.1724721 0.16897923
0.15310827 -0.18576346 -0.15209974 -0.14923184 0.16177008 0.17033686
0.17249553 -0.16442464 0.17219548 0.14163621 0.17466311 -0.16299516
0.15300398 -0.17007 -0.17571868 0.1432558 0.19443046 -0.14902169
-0.15417525 -0.158104 -0.16867223 0.17328519 -0.1458271 0.14065516
-0.16011505 -0.14400116 0.15340543 0.18147193 0.16297258 0.15671502
0.1834074 0.18648192 0.14913751 0.17229477 -0.16046892 -0.16286874
0.17249443 0.1598189 0.15457755 0.15544735]
longest [-0.11369815 -0.12043431 -0.092557 -0.09817025 0.0974712 0.11953922
0.09050267 -0.12989983 -0.12739237 0.09770411 -0.09124913 0.11923231
0.10500139 0.12321769 -0.10178895 -0.08414653 -0.11140335 0.11064935
-0.11910009 -0.13029186 -0.0819956 -0.14623053 0.10746644 0.11411982
0.07777556 0.11455971 0.10285854 -0.13111994 -0.09939577 -0.0969635
-0.09044521 -0.12611853 -0.07721555 0.13274965 0.1429023 -0.11630544
-0.12039573 0.0967311 0.10126024 0.09217947 0.09140108 0.12022021
-0.10746413 0.11575451 0.10332521 0.10784137 0.12165492 0.01202166
0.12645924 0.10078685 0.11007548 0.10907505 -0.12435152 0.11719389
0.11280338 0.1236036 0.07933119 0.10591829 -0.134402 0.1064928
0.09698504 -0.09272131 -0.11048459 -0.0654592 0.10098593 0.12799956
0.11465398 -0.1242158 0.09071227 0.11518865 0.11822289 -0.11066255
0.09791093 -0.10014951 -0.10658889 0.08308181 0.10707423 -0.11712457
-0.08655006 -0.08509547 -0.11665813 0.11491272 -0.10555442 0.11709879
-0.1214956 -0.11386202 0.12710302 0.130108 0.10406311 0.11363877
0.08845831 0.1340862 0.04878733 0.10403565 -0.11301355 -0.10326528
0.08017147 0.09825292 0.12335542 0.10292899]
此数据集为PTB数据集,待在自己的中文问题数据集上进行分布式表示和分类