在github上面找到了该论文的源码,地址如下
https://github.com/lc222/Dynamic-CNN-Sentence-Classification-TF
但是因为我的python版本是3.7,所以还需要一些改动的地方,如果大家有同样的问题,可以参考一下我的这篇博客:
1.print()的问题
作者原文中用的print()都是下面这种形式:
print 'Finish training. On test set:'
我们需要给他加上括号,结果如下:
print('Finish training. On test set:')
2.utf-8码的问题
原文中,直接进行如下操作
x_test = list(open(folder_prefix+"test").readlines())
我们需要给他加上rb
x_test = list(open(folder_prefix+"test", 'rb').readlines())
3.decode问题
原文中直接使用了clean_str函数,最终报错
x_text = [clean_str(sent) for sent in x_text]
改成:
#修改
le = len(x_text)
for i in range(le):
encode_type = chardet.detect(x_text[i])
x_text[i] = x_text[i].decode(encode_type['encoding']) # 进行相应解码,赋给原标识符(变量)
#修改
x_text = [clean_str(sent) for sent in x_text]
4.类型问题
原文中执行zip之后直接将其看成list类型进行操作:
batches = dataUtils.batch_iter(zip(x_train, y_train), batch_size, n_epochs)
我们需要对他强制转换一下:
batches = dataUtils.batch_iter(list(zip(x_train, y_train)), batch_size, n_epochs)
5.类型问题2
像/之后,Python都将其当成float类型,原文代码如下(会报错):
W2 = init_weights([ws[1], embed_dim/2, num_filters[0], num_filters[1]], "W2")
我们需要加int强制转换:
W2 = init_weights([ws[1], int(embed_dim/2), num_filters[0], num_filters[1]], "W2")
还有几处需要加int的地方,读者只要根据报错位置一一改正即可
最后运行train.py函数