java调用R语言实现神经网络

java调用R语言相关函数,本文使用Rserve方法,需要在R平台安装install.packages(‘Rserve’),然后启动,library(Rserve),Rserve()。Windows上运行Rserve有一定的局限性,开发中需要注意的是,Rserve同时只允许一个客户端连接。因此,如果第二个线程试图连接时,它就会一直处在等待状态。

java工程中需要REngine.jar和RserveEngine.jar两个文件

RConnection c = new RConnection();

c.assgin(),c.eval()两个方法进行赋值和调用R相关函数

import org.rosuda.REngine.REXPMismatchException;
import org.rosuda.REngine.REngineException;
import org.rosuda.REngine.Rserve.RConnection;

public class Demo2 {
	public static void main(String[] args) throws REXPMismatchException,
			REngineException {
		RConnection c = new RConnection();//调用R
		
	//	REXP x = c.eval("R.version.string");
	//	System.out.println(x.asString());//R语言版本信息

		/**
		 * R语言矩阵赋值给java二维数组,可以直接定义java二维数组
		 */
		double p[][] = c
				.eval("matrix(c(6977.93,24647,11356.6,9772.5,1496.92,4279.65,89.84,95.97,9194,0.6068,7973.37,28534,13469.77,11585.82,1618.27,5271.991,100.28,111.16,9442,0.63,9294.26,33272,16004.61,14076.83,1707.98,6341.86,117.78,130.22,9660,0.6314,10868.67,37638,18502.2,16321.46,1790.97,6849.688,134.77,125.56,9893,0.6337,12933.12,39436,19419.7,18052.59,1855.73,6110.941,86.04,119.81,10130,0.634,15623.7,44736,23014.53,20711.55,1948.06,7848.961,151.59,187.08,10441,0.6618,17069.2,50807,26447.38,24097.7,2006.92,9134.673,177.79,202.12,10505,0.665,18751.47,54095,27700.97,26519.69,2037.88,9840.205,195.18,282.05,10594,0.674,21169.7,60633.82,31941.45,29569.92,2211.6665,11221.01,205.5601,329.4234,10986.79,0.684065,23716.17,66750.29,35562.93,32993.75,2317.9223,12486.77,220.3005,398.7751,11245.69,0.694706,26469.74,73292.95,39458.17,36680.63,2428.5869,13849.68,235.0408,477.4204,11515.33,0.706087),11,10,byrow=T)")
				.asDoubleMatrix();
		/*for (int i = 0; i < 11; i++) {
			for (int j = 0; j < 10; j++) {
				System.out.print(p[i][j] + " ");
			}
			System.out.println();
		}*/

		/**
		 * 矩阵归一化
		 */
		double p0[][] = new double[11][10];
		for (int i = 0; i < 10; i++) {
			double max = p[0][i];
			double min = p[0][i];
			for (int jj = 1; jj < 11; jj++) {
				if (p[jj][i] > max)
					max = p[jj][i];
				if (p[jj][i] < min)
					min = p[jj][i];
			}
			double a = max - min;
			for (int j = 0; j < 11; j++) {
				p0[j][i] = (p[j][i] - min) / a;
			}
		}

		/**
		 * 归一化后的矩阵转化成字符串作为参数,方便调用R函数
		 */
		StringBuilder chulihou = new StringBuilder();
		for (int i = 0; i < 11; i++) {
			for (int j = 0; j < 10; j++) {
				chulihou.append(p0[i][j] + ",");
			}
		}

		/**
		 * 归一化后的矩阵赋值给p0
		 */
		c.assign(
				"p0",
				c.eval("matrix(c("
						+ chulihou.substring(0, chulihou.length() - 1)
						+ "),11,10,byrow=T)"));

		/**
		 * 输出归一化后的数据
		 */
		/*for (int j = 1; j < c.eval("p0[,1]").asDoubles().length + 1; j++) {
			for (int i = 1; i < c.eval("p0[1,]").asDoubles().length + 1; i++) {
				System.out.print(c.eval("p0[" + j + "," + i + "]").asDouble()
						+ " ");
			}
			System.out.println();
		}*/

		/**
		 * 测试数据归一化,并将归一化后的结果赋值给tt0
		 */
		double t[] = { 2673.5356, 2991.0529, 3393.0057, 3504.8229, 3609.4029,
				4060.1257, 4399.0168, 4619.4102 };
		double t0[] = new double[t.length];

		double tmax = t[0];
		double tmin = t[0];
		for (int i = 1; i < t.length; i++) {
			if (t[i] > tmax)
				tmax = t[i];
			if (t[i] < tmin)
				tmin = t[i];
		}
		for (int i = 0; i < t.length; i++) {
			t0[i] = (t[i] - tmin) / (tmax - tmin);
		}
		c.assign("tt0", t0);

		/**
		 * 输出测试数据归一化后的结果
		 */
		/*for (int i = 1; i < c.eval("tt0").asDoubles().length + 1; i++) {
			System.out.print(c.eval("tt0[" + i + "]").asDouble() + " ");
		}*/
		
		int count = 0;//训练次数
		double alter = 1;//训练误差
		double betteralter = 1;//训练误差用途:若训练次数大于20次,选取20次中最接近真实值的预测值
		double tt = 4830.1315;//2013年用电量作为校验数据
		double[] y = new double[3];//存放预测值
		double[] y0 = new double[3];//存放更好的预测值
		while (Math.abs(alter) > 0.03 && count < 20) {
			/**
			 * 神经网络核心代码部分
			 */
			c.assign("net",c.eval("newff(n.neurons = c(10,10,2,1),learning.rate.global=1e-4, momentum.global=0.05,error.criterium=\"LMS\", Stao=NA, hidden.layer=\"tansig\", output.layer=\"purelin\", method=\"ADAPTgdwm\")"));
			c.assign("result", c.eval("train(net,p0[1:8,],tt0[1:8],error.criterium=\"LMS\", report=TRUE, show.step=10000, n.shows=5)"));
			y = c.eval("sim(result$net,p0[9:11,])").asDoubles();
			
			/**
			 * 反归一化
			 */
			for(int i=0;i<y.length;i++){
				y[i]=y[i]*t[7];
			}
			
			alter = (y[0]-tt)/tt;//校验误差
			
			if(Math.abs(alter) < betteralter){//当前训练结果比较好,将值赋给y0
				betteralter = alter;
				for(int i=0;i<y0.length;i++){
					y0[i]=y[i];
				}
			}
			
			count++;
		}
		
		/**
		 * 输出预测结果
		 */
		String[] year = {"2013年的用电量预测: ","2014年的用电量预测: ","2015年的用电量预测: "}; 
		System.out.println("训练次数为:"+count);
		for(int i=0;i<y0.length;i++){
			System.out.print(year[i]);
			System.out.println(y0[i]);
		} 
	}
}


阅读更多
版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/jiyang_1/article/details/51881985
个人分类: 数据挖掘
想对作者说点什么? 我来说一句

没有更多推荐了,返回首页

关闭
关闭
关闭