构建卷积神经网络模型(CNN)讲解与实战

68 篇文章 0 订阅
20 篇文章 0 订阅

1 卷积神经网络模型

1.1 神经网络

神经网络可以指向两种,一个是生物神经网络,一个是人工神经网络。
生物神经网络 :一般指 生物 大脑神经元 细胞 触点 等组成的 网络 ,用于产生生物的
意识 ,帮助生物进行 思考 行动

一个神经元可以通过轴突作用于成千上万的神经元,也可以通过树突从成千上万的
神经元接受信息。上级神经元的轴突在有电信号传导时释放出化学递质,作用于下一级
神经元的树突,树突受到递质作用后产生出电信号,从而实现了神经元间的信息传递。
化学递质可以使下一级神经元兴奋或抑制。
人工神经网络 :人工神经网络( Artificial Neural Networks ,简写为 ANNs )也简称为神
经网络( NNs )或称作连接模型( Connection Model ),它是一种模仿动物神经网络行
为特征,进行分布式并行信息处理的算法数学模型。这种网络依靠系统的复杂程度,通
过调整内部大量节点之间相互连接的关系,从而达到处理信息的目的。
是一种应用类似于大脑神经突触联接的结构进行信息处理的数学模型。在工程与学术界
也常直接简称为 神经网络 或类神经网络。

1.2 卷积

英文中的 to convolve 词源为拉丁文 convolvere ,意为 卷在一起 。从数学角度说,卷
积是指用来计算一个函数通过另一个函数时,两个函数有多少重叠的积分。卷积可以视
为通过相乘的方式将两个函数进行混合。

2 代码实现

1)修改配置文件application.yml

ai:
#神经网络卷积模型
cnnModel: E:\\article.cnnmodel

2)添加工具类CnnUtil(资源中提供)

/**
* CNN工具类
*/
public class CnnUtil {
/**
* 创建计算图(卷积神经网络)
* @param cnnLayerFeatureMaps 卷积核的数量(=词向量维度)
* @return 计算图
*/
public static ComputationGraph createComputationGraph(int
cnnLayerFeatureMaps){
//训练模型
int vectorSize = 300; //向量大小
//int cnnLayerFeatureMaps = 100; 每种大小卷积层的卷积核的数
量=词向量维度
ComputationGraphConfiguration config = new
NeuralNetConfiguration.Builder()
.convolutionMode(ConvolutionMode.Same)// 设置卷积模式
.graphBuilder()
.addInputs("input")
.addLayer("cnn1", new ConvolutionLayer.Builder()//卷积层
.kernelSize(3,vectorSize)//卷积区域尺寸
.stride(1,vectorSize)//卷积平移步幅
.nIn(1)
.nOut(cnnLayerFeatureMaps)
.build(), "input")
.addLayer("cnn2", new ConvolutionLayer.Builder()
.kernelSize(4,vectorSize)
.stride(1,vectorSize)
.nIn(1)
.nOut(cnnLayerFeatureMaps)
.build(), "input")
.addLayer("cnn3", new ConvolutionLayer.Builder()
.kernelSize(5,vectorSize)
.stride(1,vectorSize)
.nIn(1)
.nOut(cnnLayerFeatureMaps)
.build(), "input")
.addVertex("merge", new MergeVertex(), "cnn1", "cnn2","cnn3")//全连接层
.addLayer("globalPool", new
GlobalPoolingLayer.Builder()//池化层
.build(), "merge")
.addLayer("out", new OutputLayer.Builder()//输出层
.nIn(3*cnnLayerFeatureMaps)
.nOut(3)
.build(), "globalPool")
.setOutputs("out")
.build();
ComputationGraph net = new ComputationGraph(config);
net.init();
return net;
}
/**
* 获取训练数据集
* @param path 分词语料库根目录
* @param childPaths 分词语料库子文件夹
* @param vecModel 词向量模型
* @return
*/
public static DataSetIterator getDataSetIterator(String path,
String[] childPaths, String vecModel ){
//加载词向量模型
WordVectors wordVectors =
WordVectorSerializer.loadStaticModel(new File(vecModel));
//词标记分类比标签
Map<String,List<File>> reviewFilesMap = new HashMap<>();
for( String childPath: childPaths){
reviewFilesMap.put(childPath, Arrays.asList(new
File(path+"/"+ childPath ).listFiles()));
}
//标记跟踪
LabeledSentenceProvider sentenceProvider = new
FileLabeledSentenceProvider(reviewFilesMap, new Random(12345));
return new CnnSentenceDataSetIterator.Builder()
.sentenceProvider(sentenceProvider)
.wordVectors(wordVectors)
.minibatchSize(32)
.maxSentenceLength(256)
.useNormalizedWordVectors(false)
.build();
}
public static Map<String, Double> predictions(String vecModel,String
cnnModel,String dataPath,String[] childPaths,String content) throws
IOException {
Map<String, Double> map = new HashMap<>();
//模型应用
ComputationGraph model =
ModelSerializer.restoreComputationGraph(cnnModel);//通过cnn模型获取计算图对
象
//加载数据集
DataSetIterator dataSet =
CnnUtil.getDataSetIterator(dataPath,childPaths, vecModel);
//通过句子获取概率矩阵对象
INDArray featuresFirstNegative = ((CnnSentenceDataSetIterator)
dataSet).loadSingleSentence(content);
INDArray predictionsFirstNegative
=model.outputSingle(featuresFirstNegative);
List<String> labels = dataSet.getLabels();
for (int i = 0; i < labels.size(); i++) {
map.put(labels.get(i) + "",
predictionsFirstNegative.getDouble(i));
}
return map;
}
}

3)创建服务类CnnService

@Service
public class CnnService {
@Value("${ai.dataPath}")
private String dataPath; //合并前的分词语料库
@Value("${ai.vecModel}")
private String vecModel; //词向量模型
@Value("${ai.cnnModel}")
private String cnnModel;//卷积神经网络模型
/**
* 生成卷积神经网络模型
*/
public void build(){
try {
//1.创建计算图对象
ComputationGraph net = CnnUtil.createComputationGraph(100);
//2.加载训练数据集
String [] childPaths={"ai","db","web"};
DataSetIterator dataSet =
CnnUtil.getDataSetIterator(dataPath, childPaths, vecModel);
//3.训练模型
net.fit(dataSet);
//4.保存模型
new File(cnnModel).delete();
ModelSerializer.writeModel(net,cnnModel,true);
} catch (Exception e) {
e.printStackTrace();
}
}
}

4)修改任务类 TrainTask

package com.tensquare.ai.task;
import com.tensquare.ai.service.CnnService;
import com.tensquare.ai.service.Word2VecService;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.scheduling.annotation.Scheduled;
import org.springframework.stereotype.Component;
/**
* 训练任务类
*/
@Component
public class TrainTask {
@Autowired
private Word2VecService word2VecService;
@Autowired
private CnnService cnnService;
/**
* 训练模型
*/
@Scheduled(cron="0 30 16 * * ?")
public void makeModel(){
System.out.println("开始合并语料库......");
word2VecService.mergeWord();
System.out.println("合并语料库结束‐‐‐‐‐‐");
System.out.println("开始构建词向量模型");
word2VecService.build();
System.out.println("构建词向量模型结束");
System.out.println("开始构建神经网络卷积模型");
cnnService.build();
System.out.println("构建神经网络卷积模型结束");
}
}

实现智能分类

3.1 需求分析

传入文本,得到所属分类信息

3.2 代码实现

1)修改CnnService,增加方法

/**
* 返回map集合 分类与百分比
* @param content
* @return
*/
public Map textClassify(String content) {
System.out.println("content:"+content);
//分词
try {
content=util.IKUtil.split(content," ");
} catch (IOException e) {
e.printStackTrace();
}
String[] childPaths={"ai","db","web"};
//获取预言结果
Map map = null;
try {
map = CnnUtil.predictions(vecModel, cnnModel, dataPath,
childPaths, content);
} catch (IOException e) {
e.printStackTrace();
}
return map;
}

2)创建AiController

package com.tensquare.ai.controller;
import com.tensquare.ai.service.CnnService;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestMethod;
import org.springframework.web.bind.annotation.RestController;
import java.util.Map;
@RestController
@RequestMapping("/cnn")
public class CnnController {
@Autowired
private CnnService cnnService;
@RequestMapping(value="/textclassify",method = RequestMethod.POST)
public Map textClassify( @RequestBody Map<String,String> content){
return cnnService.textClassify(content.get("content"));
}
}

3)使用postman测试 http://localhost:8080/cnn/textclassify 提交格式:

{
"content":"测试文本"
}

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

纵然间

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值