Java机器学习库ML之四模型训练和预测示例

153 篇文章 2 订阅
69 篇文章 120 订阅

基于ML库机器学习的步骤:

1)样本数据导入;

2)样本数据特征抽取和特征值处理(结合模型需要归一化或离散化);这里本文没有做处理,特征选择和特征值处理本身就很大;

3)样本集划分训练集和验证集;

4)根据训练集训练模型;

5)用验证集评价模型;

6)导入测试集,并用模型预测输出预测结果;

package com.vip;

import java.io.File;

import be.abeel.util.Pair;
import net.sf.javaml.classification.Classifier;
import net.sf.javaml.classification.KNearestNeighbors;
import net.sf.javaml.core.Dataset;
import net.sf.javaml.core.DefaultDataset;
import net.sf.javaml.core.DenseInstance;
import net.sf.javaml.core.Instance;
import net.sf.javaml.featureselection.scoring.GainRatio;
import net.sf.javaml.sampling.Sampling;
import net.sf.javaml.tools.data.FileHandler;

public class VIPClassifer {
	 public static void main(String[] args)throws Exception {
		    if (args.length != 2) {
				System.err.println("Usage: 输入训练集和测试集路径");
				System.exit(2);
			}
	        /* Load a data set 前面13列是训练特征,最后1列标记*/
	        Dataset ori_data = FileHandler.loadDataset(new File(args[0]), 13, "\\s+");
            //特征评分,可独立
	        //GainRatio ga = new GainRatio();        
	        //ga.build(ori_data);  /* Apply the algorithm to the data set */    
	        //for (int i = 0; i < ga.noAttributes(); i++)  
	         //   System.out.println(ga.score(i));         
	        //抽样训练集和验证集
	        Sampling s = Sampling.SubSampling;
	        Pair<Dataset, Dataset> sam_data = s.sample(ori_data, (int) (ori_data.size() * 0.8));
	        /*
	        Dataset train_data = new DefaultDataset();//80%训练
	        Dataset test_data = new DefaultDataset();//20%验证
	        int sample=0;
	        for(Instance inst:ori_data){
	        	double[] values = new double[] { inst.value(5),inst.value(6),inst.value(7),inst.value(8), inst.value(9),inst.value(16),inst.value(17)};
	        	Instance train_inst = new DenseInstance(values, inst.classValue());        	
	        	if(sample<4){
	        		sample++;
	        		train_data.add(train_inst);
	        	}else {
	        	    sample=0;
	        		test_data.add(train_inst); 
	        	}       	
	        }*/
	        //Contruct a KNN classifier that uses 5 neighbors to make a decision.
	        Classifier knn = new KNearestNeighbors(5);
	        knn.buildClassifier(sam_data.x());
	        //验证集
	        int correct = 0, wrong = 0;
	        /* Classify all instances and check with the correct class values */
	        for (Instance inst : sam_data.y()) {
	            Object predictedClassValue = knn.classify(inst);
	            Object realClassValue = inst.classValue();
	            if (predictedClassValue.equals(realClassValue))
	                correct++;
	            else
	                wrong++;
	        }
	        System.out.println("Correct predictions  " + correct);
	        System.out.println("Wrong predictions " + wrong);
	        //模型预测
	        /* Load a data set 前面13列是训练特征,最后2列是uid和spuid联合标识*/
	        Dataset pre_data = FileHandler.loadDataset(new File(args[1]),"\\s+");
	        System.out.println(pre_data.instance(0));
	        Dataset out_data = new DefaultDataset();
	        for(Instance inst:pre_data){
	        	double[] values = new double[13]; 
	        	for(int i=0;i<13;i++) values[i]=inst.value(i);
	        	Instance pre_inst = new DenseInstance(values); //无标记,13列特征参与训练
	        	Object pre_classvalue = knn.classify(pre_inst);//预测结果
	        	//pre_inst.setClassValue(pre_classvalue);//标注预测结果
	        	double[] u_spu_id=new double[]{inst.value(13),inst.value(14)};
	        	Instance out_inst = new DenseInstance(u_spu_id,pre_classvalue); //带标记
	        	out_data.add(out_inst);
	        }
	        //输出u_Id+spu_id+action_type
	        FileHandler.exportDataset(out_data, new File("/data1/DataFountain/output.txt"));
	 }
}
//java -XX:-UseGCOverheadLimit -Xmx10240m -jar vip.jar train_features_new.txt test_features_new.txt

在上面这个代码框架内,可以用不同模型,如SVM、RF(随机森林)等,也可以对特征值进行处理后选择特征来训练。

  • 2
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
1、下载并安装mysql,将脚本执行至数据中; 2、配置java环境,使用jdk8,配置环境变量,下载IntelliJ IDEA 2019.2.4,该工具为java代码编译器 3、下载Maven,配置至环境变量(百度搜索很多),将构建器为Maven,类配置成阿里(方法:百度搜索很多很多) 4、将工程导入后,在application-local.yml文件中配置数据 5、在logback-prod.xml文件中配置log日志 6、配置完毕后,即可启动 访问地址:http://localhost:8082/anime/login.html 用户名:admin 密码:admin V:china1866 1、 登录 2、 首页 3、 权限管理-用户管理 4、 权限管理-添加用户数据 5、 交通数据管理-查看交通数据 6、 交通数据管理-添加交通数据 7、 交通预测-交通数据预测 脚本: CREATE TABLE `traffic_data_t` ( `id` INT(11) NOT NULL AUTO_INCREMENT COMMENT '序列', `trafficId` VARCHAR(50) NULL DEFAULT NULL COMMENT '交通数据编号', `trafficContent` VARCHAR(50) NULL DEFAULT NULL COMMENT '交通状况', `trafficSection` VARCHAR(200) NULL DEFAULT NULL COMMENT '交通路段', `trafficMan` VARCHAR(200) NULL DEFAULT NULL COMMENT '上报人', `trafficDate` VARCHAR(200) NULL DEFAULT NULL COMMENT '上报时间', `status` VARCHAR(200) NULL DEFAULT NULL COMMENT '交通状态', PRIMARY KEY (`id`) ) COMMENT='交通数据表' COLLATE='utf8_general_ci' ENGINE=InnoDB AUTO_INCREMENT=44 ; CREATE TABLE `sys_user_t` ( `id` INT(11) NOT NULL AUTO_INCREMENT, `role_id` INT(11) NULL DEFAULT NULL COMMENT '角色ID', `user_id` VARCHAR(50) NOT NULL COMMENT '用户ID', `user_name` VARCHAR(100) NOT NULL COMMENT '用户名', `status` INT(11) NOT NULL COMMENT '是否有效0:false\\\\1:true', `create_date` TIMESTAMP NULL DEFAULT CURRENT_TIMESTAMP, `create_by` VARCHAR(100) NULL DEFAULT NULL, `last_update_date` TIMESTAMP NULL DEFAULT CURRENT_TIMESTAMP, `last_update_by` VARCHAR(100) NULL DEFAULT NULL, `password` VARCHAR(128) NOT NULL, `tenantcode` VARCHAR(50) NOT NULL, `diskId` VARCHAR(500) NULL DEFAULT NULL, `remarks` VARCHAR(500) NULL DEFAULT NULL, PRIMARY KEY (`id`) ) COMMENT='系统用户表' COLLATE='utf8_general_ci' ENGINE=InnoDB AUTO_INCREMENT=51 ; CREATE TABLE `sys_role_t` ( `role_id` INT(11) NOT NULL COMMENT '角色ID', `role_name` VARCHAR(200) NOT NULL COMMENT '权限名称', `status` INT(11) NOT NULL COMMENT '是否有效0:true\\\\1:false', `create_date` TIMESTAMP NULL DEFAULT CURRENT_TIMESTAMP, `create_by` VARCHAR(100) NULL DEFAULT NULL, `last_update_date` TIMESTAMP NULL DEFAULT CURRENT_TIMESTAMP, `last_update_by` VARCHAR(100) NULL DEFAULT NULL ) COMMENT='系统角色表' COLLATE='utf8_general_ci' ENGINE=InnoDB ; CREATE TABLE `sys_menu_t` ( `id` INT(11) NOT NULL AUTO_INCREMENT COMMENT '序列', `parent_id` VARCHAR(50) NOT NULL COMMENT '父节点ID', `menu_id` VARCHAR(50) NOT NULL COMMENT '菜单ID', `menu_name` VARCHAR(200) NOT NULL COMMENT '菜单名称', `menu_url` VARCHAR(200) NULL DEFAULT NULL COMMENT '菜单URL', `status` INT(11) NOT NULL COMMENT '有效(0有效,1失效)', `create_date` TIMESTAMP NULL DEFAULT CURRENT_TIMESTAMP, `create_by` VARCHAR(200) NULL DEFAULT NULL, `last_update_date` TIMESTAMP NULL DEFAULT CURRENT_TIMESTAMP, `last_update_by` VARCHAR(200) NULL DEFAULT NULL, PRIMARY KEY (`id`) ) COMMENT='菜单表' COLLATE='utf8_general_ci' ENGINE=InnoDB AUTO_INCREMENT=33 ; CREATE TABLE `sys_menu_role_relation_t` ( `id` INT(11) NOT NULL AUTO_INCREMENT COMMENT '序列', `menu_id` VARCHAR(50) NOT NULL COMMENT '菜单ID', `role_id` VARCHAR(50) NOT NULL COMMENT '角色ID', `status` INT(11) NOT NULL COMMENT '有效(0有效,1失效)', `create_date` TIMESTAMP NULL DEFAULT CURRENT_TIMESTAMP, `create_by` VARCHAR(200) NULL DEFAULT NULL, `last_update_date` TIMESTAMP NULL DEFAULT CURRENT_TIMESTAMP, `last_update_by` VARCHAR(200) NULL DEFAULT NULL, PRIMARY KEY (`id`) ) COMMENT='角色与菜单关系表' COLLATE='utf8_general_ci' ENGINE=InnoDB AUTO_INCREMENT=51 ;
当然可以!以下是一个简单的FlinkML线性回归的示例代码: ``` // 导入所需的类 import org.apache.flink.api.common.functions.MapFunction; import org.apache.flink.api.java.ExecutionEnvironment; import org.apache.flink.api.java.io.CsvReader; import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.ml.common.LabeledVector; import org.apache.flink.ml.math.DenseVector; import org.apache.flink.ml.math.Vector; import org.apache.flink.ml.regression.LinearRegression; import org.apache.flink.ml.regression.LinearRegressionModel; // 创建执行环境 ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); // 读取数据集 CsvReader reader = env.readCsvFile("path/to/dataset.csv") .fieldDelimiter(",") .ignoreFirstLine() .ignoreInvalidLines() .includeFields(true, true, true); // 将数据集转换为 LabeledVector 格式 DataSet<LabeledVector> data = reader.map(new MapFunction<Tuple3<Double, Double, Double>, LabeledVector>() { @Override public LabeledVector map(Tuple3<Double, Double, Double> value) throws Exception { double label = value.f0; Vector features = new DenseVector(new double[]{value.f1, value.f2}); return new LabeledVector(label, features); } }); // 创建线性回归模型 LinearRegression linearRegression = new LinearRegression() .setIterations(10) // 设置迭代次数 .setStepsize(0.5) // 设置步长 .setConvergenceThreshold(0.001); // 设置收敛阈值 // 训练模型 LinearRegressionModel model = linearRegression.fit(data); // 预测 Vector prediction = model.predict(new DenseVector(new double[]{1.0, 2.0})); // 输出预测结果 System.out.println("Prediction: " + prediction); ``` 这段代码实现了一个简单的线性回归模型,将数据集读取为 LabeledVector 格式,然后使用 LinearRegression 类创建模型训练。最后,使用模型进行预测并输出结果。如果需要使用其他模型,只需更改模型的类型即可。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值