底层方法
package cn.kgc.java2scala;
import java.lang.reflect.Field;
import java.lang.reflect.Modifier;
import java.util.*;
import java.util.regex.Pattern;
/**
 * A small in-memory, Spark-style data frame whose initial columns are the instance fields of an
 * entity type.
 *
 * <p>Columns are addressed by name and stored by 1-based position. Every transformation
 * ({@code select}, {@code where}, {@code groupBy}, {@code agg}, {@code orderBy}, {@code join})
 * mutates this instance and returns {@code this}, so calls can be chained. Instances are mutable
 * and unsynchronized.
 *
 * @param <T> entity type whose fields define the initial columns
 */
public class DataFrame<T> {

    /** Row-filter operators accepted by {@link #where}. */
    enum Condition {
        EQ, NE, GT, GE, LT, LE, LIKE, IN
    }

    /** Aggregation functions accepted by {@link #agg}. */
    enum Agg {
        COUNT, AVG, SUM, MIN, MAX, COLLECT_LIST, COLLECT_SET
    }

    /** Column descriptor: a name plus either the reflective field backing it or an explicit type. */
    static class ColHead {
        private String name;
        private Field field;
        private Class<?> type;

        public ColHead(String name, Field field) {
            this.name = name;
            this.field = field;
        }

        public ColHead(String name, Class<?> type) {
            this.name = name;
            this.type = type;
        }

        /** Returns the backing field's type when there is one, otherwise the explicit type. */
        Class<?> javaType() {
            return null != field ? field.getType() : type;
        }
    }

    /** Names of the primitive and boxed numeric types; a column of one of these is numeric. */
    static Set<String> types;

    static {
        types = new HashSet<>(Arrays.asList(
                "byte", "short", "int", "long", "float", "double",
                "java.lang.Byte", "java.lang.Short", "java.lang.Integer", "java.lang.Long",
                "java.lang.Float", "java.lang.Double"));
    }

    /** Column position (1-based) -> descriptor; a TreeMap so iteration follows column order. */
    private Map<Integer, ColHead> schema;
    /** Column name -> column position (1-based). */
    private Map<String, Integer> validator;
    /** The rows; {@code row[p - 1]} holds the value of the column at position {@code p}. */
    private List<Object[]> data;
    /** Columns recorded by {@link #groupBy} for the next {@link #agg}; null when none is pending. */
    private List<String> groupFields;

    /**
     * (Re)builds the schema from the instance fields of {@code type} and empties the data.
     * Column order follows {@link Class#getDeclaredFields()}, which in practice is declaration
     * order but is not guaranteed by the JDK.
     *
     * @throws NullPointerException if {@code type} is null
     */
    private void init(Class<T> type) {
        if (null == type) {
            throw new NullPointerException("schema can't be NULL");
        }
        this.schema = new TreeMap<>();
        this.validator = new HashMap<>();
        int count = 0;
        for (Field f : type.getDeclaredFields()) {
            // static fields are not per-row data, and synthetic fields (e.g. an inner class's
            // reference to its enclosing instance) are compiler artefacts: neither is a column
            if (Modifier.isStatic(f.getModifiers()) || f.isSynthetic()) {
                continue;
            }
            f.setAccessible(true);
            this.schema.put(++count, new ColHead(f.getName(), f));
            this.validator.put(f.getName(), count);
        }
        this.data = new ArrayList<>();
        // a new data set invalidates any pending grouping
        this.groupFields = null;
    }

    /** Appends one raw row. */
    private void fill(Object[] row) {
        data.add(row);
    }

    /** Converts one entity into a row (one cell per schema column) and appends it. */
    private void fill(T t) {
        Object[] row = new Object[schema.size()];
        for (Map.Entry<Integer, ColHead> e : schema.entrySet()) {
            ColHead ch = e.getValue();
            try {
                row[e.getKey() - 1] = ch.field.get(t);
            } catch (IllegalAccessException ex) {
                throw new RuntimeException("illegal access of field " + ch.name, ex);
            }
        }
        fill(row);
    }

    /** Throws unless every name in {@code fields} is an existing column. */
    private void validateFields(String... fields) {
        if (null == validator) {
            throw new RuntimeException("no schema");
        }
        for (String field : fields) {
            if (!validator.containsKey(field)) {
                throw new RuntimeException("no column " + field + " found");
            }
        }
    }

    /** Returns true if {@code field} equals one of {@code fields}. */
    private boolean contains(String field, String... fields) {
        for (String candidate : fields) {
            if (field.equals(candidate)) {
                return true;
            }
        }
        return false;
    }

    /** Returns true if the column's declared type is one of {@link #types}. */
    private boolean isNumeric(ColHead ch) {
        return types.contains(ch.javaType().getName());
    }

    /**
     * Replaces the content of this frame with one row per element of {@code cs}. The runtime
     * class of the first element defines the columns.
     *
     * @throws RuntimeException if {@code cs} is null or empty
     */
    @SuppressWarnings("unchecked")
    public DataFrame<T> fill(Collection<T> cs) {
        if (null == cs || cs.isEmpty()) {
            throw new RuntimeException("data NullPointer Exception");
        }
        init((Class<T>) cs.iterator().next().getClass());
        for (T t : cs) {
            fill(t);
        }
        return this;
    }

    /**
     * Projection: keeps only {@code fields}, in the given order. With no arguments the frame is
     * left unchanged.
     *
     * @throws RuntimeException if a field is not a column
     */
    public DataFrame<T> select(String... fields) {
        if (fields.length == 0) {
            return this;
        }
        validateFields(fields);
        // snapshot the old positions before any structure is rebuilt
        int[] oldIx = new int[fields.length];
        Map<Integer, ColHead> newSchema = new TreeMap<>();
        Map<String, Integer> newValidator = new HashMap<>();
        for (int i = 0; i < fields.length; i++) {
            int oldPos = validator.get(fields[i]);
            oldIx[i] = oldPos - 1;
            newSchema.put(i + 1, schema.get(oldPos));
            newValidator.put(fields[i], i + 1);
        }
        List<Object[]> projected = new ArrayList<>(data.size());
        for (Object[] row : data) {
            Object[] out = new Object[oldIx.length];
            for (int i = 0; i < oldIx.length; i++) {
                out[i] = row[oldIx[i]];
            }
            projected.add(out);
        }
        this.schema = newSchema;
        this.validator = newValidator;
        this.data = projected;
        return this;
    }

    /**
     * Row filter on one column. Cells are compared by their {@code toString()} text against
     * {@code val[0]}:
     * <ul>
     *   <li>EQ and NE compare the text for equality.</li>
     *   <li>GT, GE, LT and LE compare parsed doubles and require a numeric column.</li>
     *   <li>LIKE is a full Java-regex match ({@code String.matches} semantics), not SQL LIKE.</li>
     *   <li>IN keeps rows whose text equals any of {@code val}.</li>
     * </ul>
     * A null cell never satisfies a condition and its row is dropped.
     *
     * @throws RuntimeException if the field is unknown, or if an ordering condition is used on a
     *     non-numeric column
     * @throws IllegalArgumentException if no (non-null) first value is given
     */
    public DataFrame<T> where(String field, Condition cond, Object... val) {
        validateFields(field);
        int pos = validator.get(field);
        ColHead ch = schema.get(pos);
        boolean ordering = cond == Condition.GT || cond == Condition.GE
                || cond == Condition.LT || cond == Condition.LE;
        if (ordering && !isNumeric(ch)) {
            throw new RuntimeException(field + " is not numeric so it can't be conditioned by " + cond);
        }
        if (null == val || val.length == 0 || null == val[0]) {
            throw new IllegalArgumentException("no value given for " + cond + " condition on " + field);
        }
        final int ix = pos - 1;
        final String target = val[0].toString();
        // parse / compile the operand once instead of once per row
        final double bound = ordering ? Double.parseDouble(target) : 0d;
        final Pattern regex = cond == Condition.LIKE ? Pattern.compile(target) : null;
        String[] values = null;
        if (cond == Condition.IN) {
            values = new String[val.length];
            for (int i = 0; i < val.length; i++) {
                values[i] = val[i].toString();
            }
        }
        List<Object[]> kept = new ArrayList<>();
        for (Object[] row : data) {
            Object cell = row[ix];
            if (null == cell) {
                continue;
            }
            String text = cell.toString();
            boolean keep;
            switch (cond) {
                case EQ:
                    keep = text.equals(target);
                    break;
                case NE:
                    keep = !text.equals(target);
                    break;
                case GT:
                    keep = Double.parseDouble(text) > bound;
                    break;
                case GE:
                    keep = Double.parseDouble(text) >= bound;
                    break;
                case LT:
                    keep = Double.parseDouble(text) < bound;
                    break;
                case LE:
                    keep = Double.parseDouble(text) <= bound;
                    break;
                case LIKE:
                    keep = regex.matcher(text).matches();
                    break;
                case IN:
                    keep = contains(text, values);
                    break;
                default:
                    throw new IllegalStateException("unsupported condition " + cond);
            }
            if (keep) {
                kept.add(row);
            }
        }
        this.data = kept;
        return this;
    }

    /**
     * Records grouping columns for the next {@link #agg} call. Repeated calls accumulate columns,
     * and duplicates are ignored. Rows are grouped when {@code agg} runs, so filters or sorts
     * applied in between are taken into account.
     *
     * @throws RuntimeException if a field is not a column
     */
    public DataFrame<T> groupBy(String field, String... fields) {
        validateFields(field);
        validateFields(fields);
        if (null == groupFields) {
            groupFields = new ArrayList<>(1 + fields.length);
        }
        addGroupField(field);
        for (String f : fields) {
            addGroupField(f);
        }
        return this;
    }

    /** Adds {@code field} to the pending grouping unless it is already there. */
    private void addGroupField(String field) {
        if (!groupFields.contains(field)) {
            groupFields.add(field);
        }
    }

    /**
     * Aggregates {@code field} into a new column named {@code alias}.
     * <ul>
     *   <li>Without a pending {@link #groupBy}, the frame becomes a single row with a single
     *       column.</li>
     *   <li>With a pending grouping, the frame becomes one row per distinct combination of the
     *       group columns, in first-seen order. The row holds the group values followed by the
     *       aggregate.</li>
     * </ul>
     * The pending grouping is consumed.
     *
     * <p>Value semantics:
     * <ul>
     *   <li>SUM, MIN, MAX and AVG are computed as {@code Float} over non-null cells; they yield
     *       null when there are none.</li>
     *   <li>COUNT counts rows.</li>
     *   <li>COLLECT_LIST and COLLECT_SET gather the cells' {@code toString()} text.</li>
     * </ul>
     * The result can be filtered afterwards with {@link #where}.
     *
     * @throws RuntimeException if the field is unknown, if a numeric aggregation is applied to a
     *     non-numeric column, or if the field is itself a group column
     */
    public DataFrame<T> agg(String field, Agg type, String alias) {
        validateFields(field);
        int pos = validator.get(field);
        ColHead colHead = schema.get(pos);
        switch (type) {
            case SUM: case MAX: case MIN: case AVG:
                if (!isNumeric(colHead)) {
                    throw new RuntimeException(type + " Aggregation can't be applied to field [ " + field + " ], because it isn't numeric.");
                }
                break;
            default:
                break;
        }
        final int ix = pos - 1;
        Map<Integer, ColHead> newSchema = new TreeMap<>();
        Map<String, Integer> newValidator = new HashMap<>();
        List<Object[]> result = new ArrayList<>();
        if (null == groupFields) {
            // whole-frame aggregation
            result.add(new Object[]{aggregate(data, ix, type)});
        } else {
            if (groupFields.contains(field)) {
                throw new RuntimeException("aggregation operation shouldn't be applied to group field " + field);
            }
            // a select between groupBy and agg may have dropped a group column
            validateFields(groupFields.toArray(new String[0]));
            int[] keyIx = new int[groupFields.size()];
            for (int i = 0; i < keyIx.length; i++) {
                String gf = groupFields.get(i);
                int gPos = validator.get(gf);
                keyIx[i] = gPos - 1;
                newSchema.put(i + 1, schema.get(gPos));
                newValidator.put(gf, i + 1);
            }
            for (Map.Entry<List<Object>, List<Object[]>> e : groupRows(keyIx).entrySet()) {
                Object[] row = Arrays.copyOf(e.getKey().toArray(), keyIx.length + 1);
                row[keyIx.length] = aggregate(e.getValue(), ix, type);
                result.add(row);
            }
        }
        int aliasPos = newSchema.size() + 1;
        newSchema.put(aliasPos, new ColHead(alias, resultType(type)));
        newValidator.put(alias, aliasPos);
        this.schema = newSchema;
        this.validator = newValidator;
        this.data = result;
        // the pending grouping has been consumed
        this.groupFields = null;
        return this;
    }

    /** Partitions the rows by their values at {@code keyIx}, keeping groups in first-seen order. */
    private Map<List<Object>, List<Object[]>> groupRows(int[] keyIx) {
        Map<List<Object>, List<Object[]>> groups = new LinkedHashMap<>();
        for (Object[] row : data) {
            Object[] key = new Object[keyIx.length];
            for (int i = 0; i < keyIx.length; i++) {
                key[i] = row[keyIx[i]];
            }
            // element-wise equals/hashCode of the list: no string joining, so no separator clashes
            groups.computeIfAbsent(Arrays.asList(key), k -> new ArrayList<>()).add(row);
        }
        return groups;
    }

    /** Declared Java type of the column produced by an aggregation of the given kind. */
    private static Class<?> resultType(Agg type) {
        switch (type) {
            case COUNT:
                return Integer.class;
            case COLLECT_LIST:
            case COLLECT_SET:
                return List.class;
            default:
                return Float.class;
        }
    }

    /** Computes one aggregate over column {@code ix} of {@code rows}; see {@link #agg}. */
    private static Object aggregate(List<Object[]> rows, int ix, Agg type) {
        if (type == Agg.COUNT) {
            return rows.size();
        }
        if (type == Agg.COLLECT_LIST || type == Agg.COLLECT_SET) {
            Collection<String> acc = type == Agg.COLLECT_LIST
                    ? new ArrayList<String>() : new HashSet<String>();
            for (Object[] row : rows) {
                if (null != row[ix]) {
                    acc.add(row[ix].toString());
                }
            }
            return Arrays.asList(acc.toArray());
        }
        Float acc = null;
        int n = 0;
        for (Object[] row : rows) {
            if (null == row[ix]) {
                continue;
            }
            float v = Float.parseFloat(row[ix].toString());
            n++;
            if (null == acc) {
                acc = v;
            } else if (type == Agg.MIN) {
                if (acc > v) {
                    acc = v;
                }
            } else if (type == Agg.MAX) {
                if (acc < v) {
                    acc = v;
                }
            } else {
                acc = acc + v; // SUM and AVG
            }
        }
        if (null == acc) {
            return null; // no non-null values to aggregate
        }
        if (type == Agg.AVG) {
            return acc / n;
        }
        return acc;
    }

    /**
     * Stable sort by one column. Numeric columns compare by parsed double value, others by
     * {@code toString()}. With no flag (or a null array) the order is ascending; otherwise
     * {@code asc[0]} chooses. Nulls sort last when ascending and first when descending.
     *
     * @throws RuntimeException if the field is not a column
     */
    public DataFrame<T> orderBy(String field, boolean... asc) {
        validateFields(field);
        int pos = validator.get(field);
        final int ix = pos - 1;
        boolean ascend = null == asc || asc.length == 0 || asc[0];
        Comparator<Object> byValue;
        if (isNumeric(schema.get(pos))) {
            byValue = Comparator.comparingDouble((Object v) -> Double.parseDouble(v.toString()));
        } else {
            byValue = Comparator.comparing((Object v) -> v.toString());
        }
        Comparator<Object[]> byRow = Comparator.comparing((Object[] row) -> row[ix], Comparator.nullsLast(byValue));
        this.data.sort(ascend ? byRow : byRow.reversed());
        return this;
    }

    /**
     * Inner join of this frame with {@code another}. Keys match when their {@code toString()}
     * texts are equal, and null keys never match. {@code joinFields[0]} names the key column in
     * this frame; {@code joinFields[1]}, if present, names it in {@code another} (otherwise the
     * same name is used).
     *
     * <p>The result has this frame's columns followed by {@code another}'s. If a name appears in
     * both, lookups by that name resolve to {@code another}'s column. Rows follow this frame's
     * order, then {@code another}'s order among the matches. {@code another} is not modified.
     *
     * @throws RuntimeException if no join field is given or a join field is unknown
     */
    public DataFrame<T> join(DataFrame<?> another, String... joinFields) {
        if (joinFields.length == 0 || null == joinFields[0]) {
            throw new RuntimeException("no join fields exception on join opertion");
        }
        final String thisJoinField = joinFields[0];
        final String anotherJoinField = joinFields.length == 1 ? joinFields[0] : joinFields[1];
        validateFields(thisJoinField);
        another.validateFields(anotherJoinField);
        final int thisIx = this.validator.get(thisJoinField) - 1;
        final int otherIx = another.validator.get(anotherJoinField) - 1;

        // merged structure: this frame's columns first, then the other frame's
        Map<Integer, ColHead> mergedSchema = new TreeMap<>();
        Map<String, Integer> mergedValidator = new HashMap<>();
        int pos = 0;
        for (ColHead ch : this.schema.values()) {
            mergedSchema.put(++pos, ch);
            mergedValidator.put(ch.name, pos);
        }
        final int thisWidth = pos;
        for (ColHead ch : another.schema.values()) {
            mergedSchema.put(++pos, ch);
            mergedValidator.put(ch.name, pos);
        }
        final int width = pos;

        // index the other side by key text so each lookup is O(1) instead of a full scan
        Map<String, List<Object[]>> index = new HashMap<>();
        for (Object[] right : another.data) {
            Object key = right[otherIx];
            if (null != key) {
                index.computeIfAbsent(key.toString(), k -> new ArrayList<>()).add(right);
            }
        }
        List<Object[]> joined = new ArrayList<>();
        for (Object[] left : this.data) {
            Object key = left[thisIx];
            if (null == key) {
                continue;
            }
            for (Object[] right : index.getOrDefault(key.toString(), Collections.emptyList())) {
                Object[] row = Arrays.copyOf(left, width);
                System.arraycopy(right, 0, row, thisWidth, width - thisWidth);
                joined.add(row);
            }
        }
        this.schema = mergedSchema;
        this.validator = mergedValidator;
        this.data = joined;
        return this;
    }

    /**
     * Prints the column positions, names and declared types.
     *
     * @throws RuntimeException if no data has been filled yet
     */
    public void printSchema() {
        if (null == schema) {
            throw new RuntimeException("no schema");
        }
        System.out.println("root");
        for (Map.Entry<Integer, ColHead> e : schema.entrySet()) {
            ColHead ch = e.getValue();
            System.out.println(" |--> " + e.getKey() + " , " + ch.name + " , "
                    + ch.javaType().getName());
        }
    }

    /**
     * Prints a header and the rows, tab-separated. With an argument, at most {@code rows[0]}
     * rows are printed. The frame is not modified, so it can be shown repeatedly.
     *
     * @throws RuntimeException if {@code rows[0] <= 0} or no data has been filled yet
     */
    public void show(int... rows) {
        if (rows.length > 0 && rows[0] <= 0) {
            throw new RuntimeException("show " + rows[0] + " is not allowed");
        }
        if (null == schema) {
            throw new RuntimeException("no schema");
        }
        System.out.println("---------------------------------------------");
        for (ColHead ch : schema.values()) {
            System.out.print(ch.name);
            System.out.print("\t");
        }
        System.out.println("\n---------------------------------------------");
        int limit = rows.length == 0 ? data.size() : Math.min(rows[0], data.size());
        for (int i = 0; i < limit; i++) {
            for (Object cell : data.get(i)) {
                System.out.print(cell);
                System.out.print("\t");
            }
            System.out.println();
        }
    }
}
测试类:
package cn.kgc.java2scala;
import java.util.Arrays;
import java.util.List;
/**
 * Demo driver for {@code DataFrame}. It builds three small tables from entity lists and joins the
 * students table with the scores table. The commented-out calls in {@code main} show the other
 * available operations; uncomment them to try them.
 */
public class Test {
    /** A student row: name, gender and age. */
    public static class Student {
        private String name;
        private String gender;
        private Integer age;

        public Student(String name, String gender, Integer age) {
            this.name = name;
            this.gender = gender;
            this.age = age;
        }
    }

    /** A student's home location: name, province and city. */
    public static class Student2 {
        private String name;
        private String province;
        private String city;

        public Student2(String name, String province, String city) {
            this.name = name;
            this.province = province;
            this.city = city;
        }
    }

    /** One exam result: student name, subject and score. */
    public static class Score {
        private String name;
        private String subject;
        private Integer score;

        public Score(String name, String subject, Integer score) {
            this.name = name;
            this.subject = subject;
            this.score = score;
        }
    }

    public static void main(String[] args) {
        List<Student> arr = Arrays.asList(
                new Student("henry", "male", 38),
                new Student("pola", "female", 28),
                new Student("ariel", "female", 18),
                new Student("jack", "male", 26),
                new Student("apple", "male", 26),
                new Student("rose", "female", 22));
        List<Student2> arr2 = Arrays.asList(
                new Student2("henry", "jiangsu", "nanjing"),
                new Student2("pola", "anhui", "hefei"),
                new Student2("ariel", "jiangsu", "yancheng"),
                new Student2("jack", "jiangsu", "suzhou"),
                new Student2("apple", "anhui", "wuhu"),
                new Student2("rose", "anhui", "liuhe"));
        List<Score> arr3 = Arrays.asList(
                new Score("henry", "java", 77),
                new Score("pola", "java", 88),
                new Score("ariel", "java", 76),
                new Score("jack", "java", 66),
                new Score("apple", "java", 75),
                new Score("rose", "java", 65),
                new Score("henry", "mysql", 84),
                new Score("pola", "mysql", 83),
                new Score("ariel", "mysql", 79),
                new Score("jack", "mysql", 68),
                new Score("apple", "mysql", 62),
                new Score("rose", "mysql", 82)
        );
        // typed (not raw) frames so the compiler checks what is passed to fill
        DataFrame<Student2> df2 = new DataFrame<Student2>().fill(arr2);
        DataFrame<Score> df3 = new DataFrame<Score>().fill(arr3);
        new DataFrame<Student>()
                .fill(arr)
                // .select("gender", "age")
                // .select("gender", "name", "age")
                // .where("gender", DataFrame.Condition.EQ, "female")
                // .where("name", DataFrame.Condition.IN, "pola", "rose")
                // .where("age", DataFrame.Condition.GT, 25)
                // .where("age", DataFrame.Condition.LT, 35)
                // .where("name", DataFrame.Condition.LIKE, "^a.*")
                // .groupBy("gender")
                // .agg("age", DataFrame.Agg.AVG, "avgAge")
                // .where("avgAge", DataFrame.Condition.GT, 25)
                // .orderBy("age", false)
                // .join(df2, "name")
                .join(df3, "name")
                // .printSchema();
                .show();
    }
}