Introduction
Deeplearning4j ships with an English text-classification example. Although it targets English, the same recipe carries over to Chinese once the text has been word-segmented: first train a word2vec model to produce word vectors (see the generation code in the previous post), then build a CNN on top of those vectors for sentence classification.
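For completeness, here is a minimal sketch of that word2vec step; the corpus path and hyper-parameters are illustrative assumptions, not the exact settings from the previous post:
import java.io.File;
import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
import org.deeplearning4j.models.word2vec.Word2Vec;
import org.deeplearning4j.text.sentenceiterator.BasicLineIterator;
import org.deeplearning4j.text.sentenceiterator.SentenceIterator;
import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;
public class TrainWord2Vec {
    public static void main(String[] args) throws Exception {
        // One segmented sentence per line (hypothetical path)
        SentenceIterator iter = new BasicLineIterator(new File("adx/corpus_segmented.txt"));
        Word2Vec vec = new Word2Vec.Builder()
                .minWordFrequency(5)
                .layerSize(100) // must match vectorSize in TrainAdxCnnModel below
                .windowSize(5)
                .iterate(iter)
                .tokenizerFactory(new DefaultTokenizerFactory())
                .build();
        vec.fit();
        // Text format, so it can be read back with WordVectorSerializer.loadTxt(...)
        WordVectorSerializer.writeWordVectors(vec, "adx/word2vec.model");
    }
}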
Training data format (one example per line: the label, a tab, then the space-separated tokens from the segmenter; here 1 marks promotional/spam text and 0 marks ordinary review text):
1 小额贷款 要 吗
1 商家 @ 回应 十里洋场 江景会 所 订房 经理 电话 ️ 五星级 餐饮 加 五星级 k歌 包房 体验 打 8折 送 酒
1 可以 啊 微信同号
1 看头像 加微信
1 看头像 加微信
1 您 可以 加研欧 书院 微信公众号 里面 有 具体 收费 标 淮 本月 25日 书院 开展 试听 公开课 如果 您 有时间 可以来 试听
0 强烈推荐 4号 4号 店长 非常 满意 朋友 发型 合适 我的 开心 因为 自然 卷 每年 得 做 头发 去过 不少 理发店 接触 不少 发型师 到 目前为止 觉得 4号 ( 帅哥 一枚 ) 最 称心如意 会 根据 发质 情况 自己的 想法 给你 设计 适合 发型 不仅 专业 而且 非常 负责 细心 我 最喜欢 细心 人 因为 只有 这样 会 做出 好的 效果 这次 先 把 发型 做出来 4号 店长 建议 染色 那样 效果 会 更好 我 不用 那么 辛苦 坐 太久 贴心 是不是 总之 满意 染色 后的 效果 更 重要 4号 亲自动手 帮 我 做 发型 强烈推荐 大家 过来 找 4号 店长 下次 来 找 你 哦 4号 店长 支持 你 djehfjdushdjd
0 非常感谢 老师 给 小高 5 星 点评 想起 我 老师 一起 游览 情景 历历在目 时间 有时 像 一个 小偷 不知不觉 四月 已经 过完 2017年 已经 过去 三分之一 像 老话 连雨 不知 春 一 晴 方知夏 深 还没 来得及 好好 享受 温熏 春光 夏日 风 已 从 远处 徐徐 而来 夏天 青岛 最 美的 青岛 避暑 理想 之 地 欢迎 老师 有机会 夏天 再来 青岛 我的 手机号 即 微信号 到时 记得 微信 小高 我 来 帮 您 订 酒店 青岛 小高 祝 老师 身体健康 阖家欢乐 万事如意
0 我 来 时候 他家 根本 拒绝 兑换 也是 遇 得到 气死人
0 商家回应 尊敬 客人 感谢您 抽出 宝贵 时间 给 我们 评价 心瑞 国际 月子会所 来源于 台湾 拥有 26年 台式 母婴护理 经验 聚集 最 专业 最 精湛 台湾 护理 技术 团队 精心 定制 专属 护理服务 秉承 规范 操作 细致入微 精益求精 服务理念 为 产 后妈 咪 提供 从 饮食 护理 康复 早教 心理健康 等 全方位 贴身 服务模式 给 孕产 期间 家庭 全方位 专业 照护 舒适 体验 会所 紧邻 国内 最好 医院 协和医院 产后 妈咪 宝宝 都有 坚实 医疗保障 我们 免费提供 健身会所 游泳馆 给 入住 客人 家属 使用 , 我们 会 不定期 举办 丰富多彩 活动 让 更多 孕妈咪 们 了解 孕期 保健知识 新生儿 喂养 知识 哦 非常 期待 下次 妈咪 见面 哦 心瑞 国际 月子会所 全体员工 祝 服务热线 座机号码 请关注 微信号 微信账号
0 商家回应 亲爱 贵宾 感谢您 襄阳 巴厘岛 休闲 度假酒店 肯定 支持 酒店 休闲会所 主要 提供 休闲 洗浴 游泳 桑拿 干湿 蒸 足疗 按摩 等 项目 如此 次 体验 没能 让 您 满意 我们 表示 深深 歉意 我们 足疗 专业 休闲 手法 所有 技师 素质 高 够 专业 各 具 魅力 服务 更好 巴厘岛 一切 以 您 最 舒适 休闲 方式 为先 您 满意 我们 继续 进步 动力 酒店 全体员工 期待 您 再次光临
0 人 直接 不行 必须 每人 点 一份 主食 我 那 钱 我 给你 东西 不用 上了 吃 不掉 浪费 服务员 b 那 再 来个 茶 吧 饭 毕 服务员 c 过来 我们 要 续 杯茶 他 可能 听懂 把 bill 拿走 一会 送来 一个新 上面 加 一杯 茶 我们 收费 那 不要 他 直接 气呼呼 拿走 新 bill 回来 时候 把 新 打印 去掉 杯茶 bill 直接 摔 桌子 然后 回头 走 我 擦 你 真 牛逼 我 走 这么多 英联邦 国家 地区 别人 一看 中国 人 客客气气 何况 还是 顾客 你 妹的
0 维权 群 怎么 加 我 被 套路
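The post does not say which segmenter produced these tokens. As one possibility, a HanLP-based sketch that turns a raw sentence into this space-separated form:
import com.hankcs.hanlp.HanLP;
import com.hankcs.hanlp.seg.common.Term;
public class SegmentDemo {
    public static void main(String[] args) {
        String raw = "可以啊微信同号"; // unsegmented input
        StringBuilder sb = new StringBuilder();
        for (Term t : HanLP.segment(raw)) {
            sb.append(t.word).append(' '); // space-separated tokens, as in the training file
        }
        System.out.println("1\t" + sb.toString().trim()); // label + tab + tokens
    }
}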
Training code:
LabeledSentence.java
package com.dianping.cnn.textclassify;
import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;
import org.datavec.api.util.RandomUtils;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.iterator.LabeledSentenceProvider;
import org.nd4j.linalg.collection.CompactHeapStringList;
public class LabeledSentence implements LabeledSentenceProvider {
private int totalCount;
private Map<String, List<String>> filesByLabel;
private List<String> normList;
private List<String> negList;
private final List<String> sentenceList;
private final int[] labelIndexes;
private final Random rng;
private final int[] order;
private final List<String> allLabels;
private int cursor = 0;
public LabeledSentence(String path) {
this(path, new Random());
}
public LabeledSentence(String path, Random rng) {
totalCount = 0;
filesByLabel = new HashMap<String, List<String>>();
normList = new ArrayList<String>();
negList = new ArrayList<>();
// try-with-resources closes the reader even if parsing fails
try (BufferedReader buffered = new BufferedReader(new InputStreamReader(
new FileInputStream(path)))) {
String line;
while ((line = buffered.readLine()) != null) {
String[] parts = line.split("\t");
if (parts.length < 2) {
continue; // skip malformed lines without a label/content pair
}
String label = parts[0];
String content = parts[1];
if ("1".equalsIgnoreCase(label)) {
normList.add(content);
totalCount++;
} else if ("0".equalsIgnoreCase(label)) {
negList.add(content);
totalCount++; // count only kept lines, so the arrays sized below stay consistent
}
}
} catch (Exception e) {
e.printStackTrace();
}
System.out.println("totalCount is:"+totalCount);
filesByLabel.put("1", normList);
filesByLabel.put("0", negList);
this.rng = rng;
if (rng == null) {
order = null;
} else {
order = new int[totalCount];
for (int i = 0; i < totalCount; i++) {
order[i] = i;
}
RandomUtils.shuffleInPlace(order, rng);
}
allLabels = new ArrayList<>(filesByLabel.keySet());
Collections.sort(allLabels);
Map<String, Integer> labelsToIdx = new HashMap<>();
for (int i = 0; i < allLabels.size(); i++) {
labelsToIdx.put(allLabels.get(i), i);
}
sentenceList = new CompactHeapStringList();
labelIndexes = new int[totalCount];
int position = 0;
for (Map.Entry<String, List<String>> entry : filesByLabel.entrySet()) {
int labelIdx = labelsToIdx.get(entry.getKey());
for (String f : entry.getValue()) {
sentenceList.add(f);
labelIndexes[position] = labelIdx;
position++;
}
}
}
@Override
public boolean hasNext() {
return cursor < totalCount;
}
@Override
public Pair<String, String> nextSentence() {
int idx;
if (rng == null) {
idx = cursor++;
} else {
idx = order[cursor++];
}
String label = allLabels.get(labelIndexes[idx]);
String sentence = sentenceList.get(idx);
return new Pair<>(sentence, label);
}
@Override
public void reset() {
cursor = 0;
if (rng != null) {
RandomUtils.shuffleInPlace(order, rng);
}
}
@Override
public int totalNumSentences() {
return totalCount;
}
@Override
public List<String> allLabels() {
return allLabels;
}
@Override
public int numLabelClasses() {
return allLabels.size();
}
}
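A quick sanity check of the provider before wiring it into the iterator (the data path is the one the training class uses later):
import java.util.Random;
import org.deeplearning4j.berkeley.Pair;
public class ProviderSmokeTest {
    public static void main(String[] args) {
        LabeledSentence provider = new LabeledSentence("adx/rnnsenec.txt", new Random(42));
        System.out.println("labels: " + provider.allLabels());
        while (provider.hasNext()) {
            Pair<String, String> p = provider.nextSentence(); // (sentence, label)
            System.out.println(p.getSecond() + "\t" + p.getFirst());
        }
    }
}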
CnnSentenceDataSetIterator.java
package com.dianping.cnn.textclassify;
import lombok.NonNull;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.iterator.LabeledSentenceProvider;
import org.deeplearning4j.iterator.provider.LabelAwareConverter;
import org.deeplearning4j.models.embeddings.wordvectors.WordVectors;
import org.deeplearning4j.text.documentiterator.LabelAwareDocumentIterator;
import org.deeplearning4j.text.documentiterator.LabelAwareIterator;
import org.deeplearning4j.text.documentiterator.interoperability.DocumentIteratorConverter;
import org.deeplearning4j.text.sentenceiterator.SentenceIterator;
import org.deeplearning4j.text.sentenceiterator.interoperability.SentenceIteratorConverter;
import org.deeplearning4j.text.sentenceiterator.labelaware.LabelAwareSentenceIterator;
import org.deeplearning4j.text.tokenization.tokenizer.Tokenizer;
import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import java.util.*;
public class CnnSentenceDataSetIterator implements DataSetIterator {
public enum UnknownWordHandling {
RemoveWord, UseUnknownVector
}
private static final String UNKNOWN_WORD_SENTINEL = "UNKNOWN_WORD_SENTINEL";
private LabeledSentenceProvider sentenceProvider = null;
private WordVectors wordVectors;
private TokenizerFactory tokenizerFactory;
private UnknownWordHandling unknownWordHandling;
private boolean useNormalizedWordVectors;
private int minibatchSize;
private int maxSentenceLength;
private boolean sentencesAlongHeight;
private DataSetPreProcessor dataSetPreProcessor;
private int wordVectorSize;
private int numClasses;
private Map<String, Integer> labelClassMap;
private INDArray unknown;
private int cursor = 0;
private CnnSentenceDataSetIterator(Builder builder) {
this.sentenceProvider = builder.sentenceProvider;
this.wordVectors = builder.wordVectors;
this.tokenizerFactory = builder.tokenizerFactory;
this.unknownWordHandling = builder.unknownWordHandling;
this.useNormalizedWordVectors = builder.useNormalizedWordVectors;
this.minibatchSize = builder.minibatchSize;
this.maxSentenceLength = builder.maxSentenceLength;
this.sentencesAlongHeight = builder.sentencesAlongHeight;
this.dataSetPreProcessor = builder.dataSetPreProcessor;
this.numClasses = this.sentenceProvider.numLabelClasses();
this.labelClassMap = new HashMap<>();
int count = 0;
//First: sort the labels to ensure the same label assignment order (say train vs. test)
List<String> sortedLabels = new ArrayList<>(this.sentenceProvider.allLabels());
Collections.sort(sortedLabels);
for (String s : sortedLabels) {
this.labelClassMap.put(s, count++);
}
if (unknownWordHandling == UnknownWordHandling.UseUnknownVector) {
// Cache the UNK vector so getVector() can return it for out-of-vocabulary words
if (useNormalizedWordVectors) {
unknown = wordVectors.getWordVectorMatrixNormalized(wordVectors.getUNK());
} else {
unknown = wordVectors.getWordVectorMatrix(wordVectors.getUNK());
}
}
this.wordVectorSize = wordVectors.getWordVector(wordVectors.vocab().wordAtIndex(0)).length;
}
/**
* Typically used after training to convert a single sentence into features for prediction
*/
public INDArray loadSingleSentence(String sentence) {
List<String> tokens = tokenizeSentence(sentence);
int[] featuresShape = new int[] {1, 1, 0, 0};
if (sentencesAlongHeight) {
featuresShape[2] = Math.min(maxSentenceLength, tokens.size());
featuresShape[3] = wordVectorSize;
} else {
featuresShape[2] = wordVectorSize;
featuresShape[3] = Math.min(maxSentenceLength, tokens.size());
}
INDArray features = Nd4j.create(featuresShape);
int length = (sentencesAlongHeight ? featuresShape[2] : featuresShape[3]);
for (int i = 0; i < length; i++) {
INDArray vector = getVector(tokens.get(i));
INDArrayIndex[] indices = new INDArrayIndex[4];
indices[0] = NDArrayIndex.point(0);
indices[1] = NDArrayIndex.point(0);
if (sentencesAlongHeight) {
indices[2] = NDArrayIndex.point(i);
indices[3] = NDArrayIndex.all();
} else {
indices[2] = NDArrayIndex.all();
indices[3] = NDArrayIndex.point(i);
}
features.put(indices, vector);
}
return features;
}
private INDArray getVector(String word) {
INDArray vector;
if (unknownWordHandling == UnknownWordHandling.UseUnknownVector && word == UNKNOWN_WORD_SENTINEL) { //Yes, this *should* be using == for the sentinel String here
vector = unknown;
} else {
if (useNormalizedWordVectors) {
vector = wordVectors.getWordVectorMatrixNormalized(word);
} else {
vector = wordVectors.getWordVectorMatrix(word);
}
}
return vector;
}
private List<String> tokenizeSentence(String sentence) {
Tokenizer t = tokenizerFactory.create(sentence);
List<String> tokens = new ArrayList<>();
while (t.hasMoreTokens()) {
String token = t.nextToken();
if (!wordVectors.hasWord(token)) {
switch (unknownWordHandling) {
case RemoveWord:
continue;
case UseUnknownVector:
token = UNKNOWN_WORD_SENTINEL;
}
}
tokens.add(token);
}
return tokens;
}
public Map<String, Integer> getLabelClassMap() {
return new HashMap<>(labelClassMap);
}
@Override
public List<String> getLabels() {
//We don't want to just return the list from the LabeledSentenceProvider, as we sorted them earlier to do the
// String -> Integer mapping
String[] str = new String[labelClassMap.size()];
for (Map.Entry<String, Integer> e : labelClassMap.entrySet()) {
str[e.getValue()] = e.getKey();
}
return Arrays.asList(str);
}
@Override
public boolean hasNext() {
if (sentenceProvider == null) {
throw new UnsupportedOperationException("Cannot do next/hasNext without a sentence provider");
}
return sentenceProvider.hasNext();
}
@Override
public DataSet next() {
return next(minibatchSize);
}
@Override
public DataSet next(int num) {
if (sentenceProvider == null) {
throw new UnsupportedOperationException("Cannot do next/hasNext without a sentence provider");
}
List<Pair<List<String>, String>> tokenizedSentences = new ArrayList<>(num);
int maxLength = -1;
int minLength = Integer.MAX_VALUE; //Track so we know if we can skip mask creation for the "all same length" case
for (int i = 0; i < num && sentenceProvider.hasNext(); i++) {
Pair<String, String> p = sentenceProvider.nextSentence();
List<String> tokens = tokenizeSentence(p.getFirst());
maxLength = Math.max(maxLength, tokens.size());
minLength = Math.min(minLength, tokens.size()); // track the shortest sentence so the mask can be skipped when all lengths match
tokenizedSentences.add(new Pair<>(tokens, p.getSecond()));
}
if (maxSentenceLength > 0 && maxLength > maxSentenceLength) {
maxLength = maxSentenceLength;
}
int currMinibatchSize = tokenizedSentences.size();
INDArray labels = Nd4j.create(currMinibatchSize, numClasses);
for (int i = 0; i < tokenizedSentences.size(); i++) {
String labelStr = tokenizedSentences.get(i).getSecond();
if (!labelClassMap.containsKey(labelStr)) {
throw new IllegalStateException("Got label \"" + labelStr
+ "\" that is not present in list of LabeledSentenceProvider labels");
}
int labelIdx = labelClassMap.get(labelStr);
labels.putScalar(i, labelIdx, 1.0);
}
int[] featuresShape = new int[4];
featuresShape[0] = currMinibatchSize;
featuresShape[1] = 1;
if (sentencesAlongHeight) {
featuresShape[2] = maxLength;
featuresShape[3] = wordVectorSize;
} else {
featuresShape[2] = wordVectorSize;
featuresShape[3] = maxLength;
}
INDArray features = Nd4j.create(featuresShape);
for (int i = 0; i < currMinibatchSize; i++) {
List<String> currSentence = tokenizedSentences.get(i).getFirst();
for (int j = 0; j < currSentence.size() && j < maxLength; j++) { // bound by maxLength, which is valid even when maxSentenceLength is unset (-1)
INDArray vector = getVector(currSentence.get(j));
INDArrayIndex[] indices = new INDArrayIndex[4];
//TODO REUSE
indices[0] = NDArrayIndex.point(i);
indices[1] = NDArrayIndex.point(0);
if (sentencesAlongHeight) {
indices[2] = NDArrayIndex.point(j);
indices[3] = NDArrayIndex.all();
} else {
indices[2] = NDArrayIndex.all();
indices[3] = NDArrayIndex.point(j);
}
features.put(indices, vector);
}
}
INDArray featuresMask = null;
if (minLength != maxLength) {
featuresMask = Nd4j.create(currMinibatchSize, maxLength);
for (int i = 0; i < currMinibatchSize; i++) {
int sentenceLength = tokenizedSentences.get(i).getFirst().size();
if (sentenceLength >= maxLength) {
featuresMask.getRow(i).assign(1.0);
} else {
featuresMask.get(NDArrayIndex.point(i), NDArrayIndex.interval(0, sentenceLength)).assign(1.0);
}
}
}
DataSet ds = new DataSet(features, labels, featuresMask, null);
if (dataSetPreProcessor != null) {
dataSetPreProcessor.preProcess(ds);
}
cursor += ds.numExamples();
return ds;
}
@Override
public int totalExamples() {
return sentenceProvider.totalNumSentences();
}
@Override
public int inputColumns() {
return wordVectorSize;
}
@Override
public int totalOutcomes() {
return numClasses;
}
@Override
public boolean resetSupported() {
return true;
}
@Override
public boolean asyncSupported() {
return true;
}
@Override
public void reset() {
cursor = 0;
sentenceProvider.reset();
}
@Override
public int batch() {
return minibatchSize;
}
@Override
public int cursor() {
return cursor;
}
@Override
public int numExamples() {
return totalExamples();
}
@Override
public void setPreProcessor(DataSetPreProcessor preProcessor) {
this.dataSetPreProcessor = preProcessor;
}
@Override
public DataSetPreProcessor getPreProcessor() {
return dataSetPreProcessor;
}
@Override
public void remove() {
throw new UnsupportedOperationException("Not supported");
}
public static class Builder {
private LabeledSentenceProvider sentenceProvider = null;
private WordVectors wordVectors;
private TokenizerFactory tokenizerFactory = new DefaultTokenizerFactory();
private UnknownWordHandling unknownWordHandling = UnknownWordHandling.RemoveWord;
private boolean useNormalizedWordVectors = true;
private int maxSentenceLength = -1;
private int minibatchSize = 32;
private boolean sentencesAlongHeight = true;
private DataSetPreProcessor dataSetPreProcessor;
/**
* Specify how the (labelled) sentences / documents should be provided
*/
public Builder sentenceProvider(LabeledSentenceProvider labeledSentenceProvider) {
this.sentenceProvider = labeledSentenceProvider;
return this;
}
/**
* Specify how the (labelled) sentences / documents should be provided
*/
public Builder sentenceProvider(LabelAwareIterator iterator, @NonNull List<String> labels) {
LabelAwareConverter converter = new LabelAwareConverter(iterator, labels);
return sentenceProvider(converter);
}
/**
* Specify how the (labelled) sentences / documents should be provided
*/
public Builder sentenceProvider(LabelAwareDocumentIterator iterator, @NonNull List<String> labels) {
DocumentIteratorConverter converter = new DocumentIteratorConverter(iterator);
return sentenceProvider(converter, labels);
}
/**
* Specify how the (labelled) sentences / documents should be provided
*/
public Builder sentenceProvider(LabelAwareSentenceIterator iterator, @NonNull List<String> labels) {
SentenceIteratorConverter converter = new SentenceIteratorConverter(iterator);
return sentenceProvider(converter, labels);
}
/**
* Provide the WordVectors instance that should be used for training
*/
public Builder wordVectors(WordVectors wordVectors) {
this.wordVectors = wordVectors;
return this;
}
/**
* The {@link TokenizerFactory} that should be used. Defaults to {@link DefaultTokenizerFactory}
*/
public Builder tokenizerFactory(TokenizerFactory tokenizerFactory) {
this.tokenizerFactory = tokenizerFactory;
return this;
}
/**
* Specify how unknown words (those that don't have a word vector in the provided WordVectors instance) should be
* handled. Default: remove/ignore unknown words.
*/
public Builder unknownWordHandling(UnknownWordHandling unknownWordHandling) {
this.unknownWordHandling = unknownWordHandling;
return this;
}
/**
* Minibatch size to use for the DataSetIterator
*/
public Builder minibatchSize(int minibatchSize) {
this.minibatchSize = minibatchSize;
return this;
}
/**
* Whether normalized word vectors should be used. Default: true
*/
public Builder useNormalizedWordVectors(boolean useNormalizedWordVectors) {
this.useNormalizedWordVectors = useNormalizedWordVectors;
return this;
}
/**
* Maximum sentence/document length. If sentences exceed this, they will be truncated to this length by
* taking the first 'maxSentenceLength' known words.
*/
public Builder maxSentenceLength(int maxSentenceLength) {
this.maxSentenceLength = maxSentenceLength;
return this;
}
/**
* If true (default): output features data with shape [minibatchSize, 1, maxSentenceLength, wordVectorSize]<br>
* If false: output features with shape [minibatchSize, 1, wordVectorSize, maxSentenceLength]
*/
public Builder sentencesAlongHeight(boolean sentencesAlongHeight) {
this.sentencesAlongHeight = sentencesAlongHeight;
return this;
}
/**
* Optional DataSetPreProcessor
*/
public Builder dataSetPreProcessor(DataSetPreProcessor dataSetPreProcessor) {
this.dataSetPreProcessor = dataSetPreProcessor;
return this;
}
public CnnSentenceDataSetIterator build() {
if (wordVectors == null) {
throw new IllegalStateException(
"Cannot build CnnSentenceDataSetIterator without a WordVectors instance");
}
return new CnnSentenceDataSetIterator(this);
}
}
}
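Before training, it can help to pull a single minibatch from the iterator and check the feature tensor's shape; a minimal sketch, reusing the model and data paths from the training class below:
import java.io.File;
import java.util.Arrays;
import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
import org.deeplearning4j.models.embeddings.wordvectors.WordVectors;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
public class IteratorSmokeTest {
    public static void main(String[] args) throws Exception {
        WordVectors wordVectors = WordVectorSerializer.fromPair(
                WordVectorSerializer.loadTxt(new File("adx/word2vec.model")));
        DataSetIterator it = new CnnSentenceDataSetIterator.Builder()
                .sentenceProvider(new LabeledSentence("adx/rnnsenec.txt"))
                .wordVectors(wordVectors)
                .minibatchSize(32)
                .maxSentenceLength(256)
                .useNormalizedWordVectors(false)
                .build();
        DataSet batch = it.next();
        // Default sentencesAlongHeight = true gives [minibatchSize, 1, longestSentenceInBatch, wordVectorSize]
        System.out.println(Arrays.toString(batch.getFeatures().shape()));
    }
}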
TrainAdxCnnModel.java
package com.dianping.cnn.textclassify;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.UnsupportedEncodingException;
import java.util.List;
import java.util.Random;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.iterator.LabeledSentenceProvider;
import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
import org.deeplearning4j.models.embeddings.wordvectors.WordVectors;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.ConvolutionMode;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.graph.MergeVertex;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.GlobalPoolingLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.PoolingType;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.lossfunctions.LossFunctions;
public class TrainAdxCnnModel {
public static void main(String[] args) throws FileNotFoundException, UnsupportedEncodingException {
String WORD_VECTORS_PATH = "adx/word2vec.model";
// Basic configuration
int batchSize = 10;
int vectorSize = 100; // word vector dimensionality; must match the word2vec model (100 here)
int nEpochs = 3; // number of training epochs
int truncateReviewsToLength = 256; // sentences longer than 256 tokens are truncated
int cnnLayerFeatureMaps = 100; // number of feature maps (channels) per convolution layer
PoolingType globalPoolingType = PoolingType.MAX;
Random rng = new Random(100); // for shuffling the training data
// Network configuration: three parallel convolution layers with filter widths 3, 4 and 5
ComputationGraphConfiguration config = new NeuralNetConfiguration.Builder()
.weightInit(WeightInit.RELU)
.activation(Activation.LEAKYRELU)
.updater(Updater.ADAM)
.convolutionMode(ConvolutionMode.Same) //This is important so we can 'stack' the results later
.regularization(true).l2(0.0001)
.learningRate(0.01)
.graphBuilder()
.addInputs("input")
.addLayer("cnn3", new ConvolutionLayer.Builder()
.kernelSize(3,vectorSize)
.stride(1,vectorSize)
.nIn(1)
.nOut(cnnLayerFeatureMaps)
.build(), "input")
.addLayer("cnn4", new ConvolutionLayer.Builder()
.kernelSize(4,vectorSize)
.stride(1,vectorSize)
.nIn(1)
.nOut(cnnLayerFeatureMaps)
.build(), "input")
.addLayer("cnn5", new ConvolutionLayer.Builder()
.kernelSize(5,vectorSize)
.stride(1,vectorSize)
.nIn(1)
.nOut(cnnLayerFeatureMaps)
.build(), "input")
.addVertex("merge", new MergeVertex(), "cnn3", "cnn4", "cnn5") //Perform depth concatenation
.addLayer("globalPool", new GlobalPoolingLayer.Builder()
.poolingType(globalPoolingType)
.build(), "merge")
.addLayer("out", new OutputLayer.Builder()
.lossFunction(LossFunctions.LossFunction.MCXENT)
.activation(Activation.SOFTMAX)
.nIn(3*cnnLayerFeatureMaps)
.nOut(2) //2 classes: positive or negative
.build(), "globalPool")
.setOutputs("out")
.build();
ComputationGraph net = new ComputationGraph(config);
net.init();
net.setListeners(new ScoreIterationListener(1));
// Load the word vectors and create DataSetIterators for the training and test sets
System.out.println("Loading word vectors and creating DataSetIterators");
WordVectors wordVectors = WordVectorSerializer.fromPair(WordVectorSerializer.loadTxt(new File(WORD_VECTORS_PATH)));
DataSetIterator trainIter = getDataSetIterator(true, wordVectors,batchSize, truncateReviewsToLength, rng);
DataSetIterator testIter = getDataSetIterator(false, wordVectors,batchSize, truncateReviewsToLength, rng);
System.out.println("Starting training");
for (int i = 0; i < nEpochs; i++) {
net.fit(trainIter);
trainIter.reset();
// Evaluate the network on the test set after each epoch
Evaluation evaluation = net.evaluate(testIter);
testIter.reset();
System.out.println(evaluation.stats());
}
// After training: run a single sentence through the network and print the predicted class distribution
String sampleSentence = "我的 手机 是 手机号码";
INDArray sampleFeatures = ((CnnSentenceDataSetIterator) testIter).loadSingleSentence(sampleSentence);
INDArray samplePredictions = net.outputSingle(sampleFeatures);
List<String> labels = testIter.getLabels();
System.out.println("\n\nPredictions for sample sentence:");
for (int i = 0; i < labels.size(); i++) {
System.out.println("P(" + labels.get(i) + ") = " + samplePredictions.getDouble(i));
}
}
private static DataSetIterator getDataSetIterator(boolean isTraining,
WordVectors wordVectors, int minibatchSize, int maxSentenceLength,
Random rng) {
String path = isTraining ? "adx/rnnsenec.txt" : "adx/rnnsenectest.txt";
LabeledSentenceProvider sentenceProvider = new LabeledSentence(path,
rng);
return new CnnSentenceDataSetIterator.Builder()
.sentenceProvider(sentenceProvider).wordVectors(wordVectors)
.minibatchSize(minibatchSize)
.maxSentenceLength(maxSentenceLength)
.useNormalizedWordVectors(false).build();
}
}
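A note on the tensor shapes, since the merge/pooling wiring is easy to get wrong: each minibatch enters as [batchSize, 1, sentenceLength, 100]. A kernelSize(3, vectorSize) filter spans the full embedding width, so with ConvolutionMode.Same each convolution layer emits [batchSize, 100, sentenceLength, 1]; the MergeVertex concatenates the three branches along the channel dimension into [batchSize, 300, sentenceLength, 1], and global max pooling collapses that to [batchSize, 300], which is exactly the nIn(3*cnnLayerFeatureMaps) the output layer expects.
To reuse the trained network later, it can be written to disk with dl4j's ModelSerializer (the file name here is an assumption):
import org.deeplearning4j.util.ModelSerializer;
// true: also save the updater state, so training could be resumed later
ModelSerializer.writeModel(net, new File("adx/cnn-textclassify.model"), true);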
Training output (note the tiny data set, 60 training and 17 test sentences, so the perfect scores in the final epoch say little about real-world performance):
Loading word vectors and creating DataSetIterators
totalCount is:60
totalCount is:17
Starting training
Examples labeled as 0 classified by model as 0: 9 times
Examples labeled as 1 classified by model as 0: 2 times
Examples labeled as 1 classified by model as 1: 6 times
==========================Scores========================================
Accuracy: 0.8824
Precision: 0.9091
Recall: 0.875
F1 Score: 0.8917
========================================================================
Examples labeled as 0 classified by model as 0: 9 times
Examples labeled as 1 classified by model as 0: 1 times
Examples labeled as 1 classified by model as 1: 7 times
==========================Scores========================================
Accuracy: 0.9412
Precision: 0.95
Recall: 0.9375
F1 Score: 0.9437
========================================================================
Examples labeled as 0 classified by model as 0: 9 times
Examples labeled as 1 classified by model as 1: 8 times
==========================Scores========================================
Accuracy: 1
Precision: 1
Recall: 1
F1 Score: 1
========================================================================
Predictions for sample sentence:
P(0) = 0.4453294575214386
P(1) = 0.554670512676239
If you have any questions, contact me on WeChat: xuxu_ge