package cn.hhb.spark.sql;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.api.java.function.VoidFunction;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import scala.Tuple2;
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.Statement;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
/**
 * JDBC data source example: loads two MySQL tables as DataFrames with Spark SQL,
 * joins them, filters the joined rows, and writes the result back to MySQL.
 * Created by dell on 2017/7/27.
 */
public class JDBCDataSource {
    public static void main(String[] args) {
        // Create the SparkConf
        SparkConf conf = new SparkConf()
                .setAppName("JDBCDataSource").setMaster("local")
                .set("spark.testing.memory", "2147480000");
        // Create the JavaSparkContext and the SQLContext
        JavaSparkContext sc = new JavaSparkContext(conf);
        SQLContext sqlContext = new SQLContext(sc);
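        // SQLContext is the Spark 1.x entry point for DataFrame operations,
        // including the read().format("jdbc") calls used below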
        // Load each of the two MySQL tables as a DataFrame
        Map<String, String> options = new HashMap<String, String>();
        options.put("url", "jdbc:mysql://spark1:3306/testdb");
        options.put("dbtable", "student_infos");
        DataFrame studentInfosDF = sqlContext.read().format("jdbc").options(options).load();
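        // Reuse the same connection options for the second table; only "dbtable" changes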
options.put("dbtable","student_scores");
DataFrame studentScoresDF = sqlContext.read().format("jdbc").options(options).load();
        // Convert both DataFrames to JavaPairRDDs keyed by student name and join them
        JavaPairRDD<String, Tuple2<Integer, Integer>> studentsRDD =
                studentInfosDF.javaRDD().mapToPair(new PairFunction<Row, String, Integer>() {
                    @Override
                    public Tuple2<String, Integer> call(Row row) throws Exception {
                        return new Tuple2<String, Integer>(
                                row.getString(0),
                                Integer.valueOf(String.valueOf(row.get(1)))
                        );
                    }
                }).join(studentScoresDF.javaRDD().mapToPair(new PairFunction<Row, String, Integer>() {
                    @Override
                    public Tuple2<String, Integer> call(Row row) throws Exception {
                        return new Tuple2<String, Integer>(
                                String.valueOf(row.get(0)),
                                Integer.valueOf(String.valueOf(row.get(1)))
                        );
                    }
                }));
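        // The joined pair RDD maps each student name to a tuple of
        // (value from student_infos, value from student_scores), i.e. (age, score)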
        // Convert the JavaPairRDD to a JavaRDD of Rows
        JavaRDD<Row> studentRowsRDD = studentsRDD.map(new Function<Tuple2<String, Tuple2<Integer, Integer>>, Row>() {
            @Override
            public Row call(Tuple2<String, Tuple2<Integer, Integer>> tuple) throws Exception {
                // Row layout: (name, age, score)
                return RowFactory.create(tuple._1, tuple._2._1, tuple._2._2);
            }
        });
        // Keep only the rows with a score greater than 80
        JavaRDD<Row> filteredStudentRowsRDD = studentRowsRDD.filter(new Function<Row, Boolean>() {
            @Override
            public Boolean call(Row row) throws Exception {
                // Column 2 is the score
                return row.getInt(2) > 80;
            }
        });
        // Build the schema for the result DataFrame
        List<StructField> structFields = new ArrayList<StructField>();
        structFields.add(DataTypes.createStructField("name", DataTypes.StringType, true));
        structFields.add(DataTypes.createStructField("age", DataTypes.IntegerType, true));
        structFields.add(DataTypes.createStructField("score", DataTypes.IntegerType, true));
        StructType structType = DataTypes.createStructType(structFields);
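        // The field order here must match the order used in RowFactory.create above: name, age, score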
        // Use the dynamically built metadata to convert the RDD into a DataFrame
        DataFrame studentsDF = sqlContext.createDataFrame(filteredStudentRowsRDD, structType);
        Row[] rows = studentsDF.collect();
        for (Row row : rows) {
            System.out.println(row);
        }
        // Save the DataFrame's rows into a MySQL table
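        // Note: this opens a separate JDBC connection for every row, which is
        // acceptable for a small demo dataset but not for production volumes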
        studentsDF.javaRDD().foreach(new VoidFunction<Row>() {
            @Override
            public void call(Row row) throws Exception {
                String sql = "insert into good_student_infos values("
                        + "'" + row.getString(0) + "',"
                        + row.getInt(1) + ","
                        + row.getInt(2) + ")";
Class.forName("com.mysql.jdbc.Driver");
Connection conn = null;
Statement stmt = null;
try {
conn = DriverManager.getConnection(
"jdbc:mysql://spark1:3306/testdb",
"",
""
);
stmt = conn.createStatement();
stmt.executeUpdate(sql);
} catch (Exception e){
e.printStackTrace();
} finally {
if (stmt != null){
stmt.close();
}
if (conn != null){
conn.close();
}
}
}
});
        sc.close();
    }
}