Java | 实现一个 ORM 到底多简单

实现一个 ORM 到底多简单



在这篇文章中主要用到了注解、反射、动态代理、正则、Jexl3 表达式来实现一个 ORM, 在大多数框架中,都会用到这些东西,在这里只是简单的使用一下,而不是原理分析

原理

在使用的 ORM 框架中,我可以像操作对象一样操作数据的存储,这是怎么实现的,我们知道数据库是认识 SQL 语句的,但并不认识java bean 呀!同时我们在使用ORM时,需要根据ORM框架的规定定义我们的bean,这是为什么?

这是因为 ORM 为我们提供了将对象操作转化为对应的 SQL语句,例如 save(bean), 这时就需要转化成一个 insert 语句,update(bean) 这时就需要转成成对应的 update 语句

通常 insert 语句格式为

insert into 表名 (列名) values()

update 语句为

update 表名 set 列名 =where id =

上面的格式可以看出,如果我们能从对象中得出 表名 列名 ,我们也可以写一个简单的ORM框架

ORM 实现

1. 通过注解来将 Java Bean 和数据库字段关联

上篇文章中提到了一下注解,以及自定义注解和解析注解的方法,通过使用注解,我们可以完成一个简单的ORM

要想实现对数据库的操作,我们必须知道数据表名以及表中的字段名称以及类型,正如hibernate 使用注解标识 model 与数据库的映射关系一样,这里我使用了Java Persistence API

注解说明

注解作用使用说明
@Entity标记这是一个实体标注在类上,标明该类可以被 ORM 处理
@Table标记实体对应的表标注在类上,标明该类对应的数据库标明
@Id标记该字段是id标注在字段上,标明该字段为 id
@Column标记该字段对应的列信息标记在字段上,标明对应的列信息,主要对应列名

字段属性表
用来存储对象中字段与数据表列的对应关系

package com.zyndev.tool.fastsql.core;

import lombok.Data;

import javax.persistence.GenerationType;

/**
 * The type Db column info.
 *
 * @author 张瑀楠 zyndev@gmail.com
 * @version 0.0.1
 */
@Data
class DBColumnInfo {

    /**
     * (Optional) The primary key generation strategy
     * that the persistence provider must use to
     * generate the annotated entity primary key.
     */
    private GenerationType strategy = GenerationType.AUTO;

    private String fieldName;

    /**
     * The name of the column
     */
    private String columnName;

    /**
     * Whether the column is a unique key.
     */
    private boolean unique;

    /**
     * Whether the database column is nullable.
     */
    private boolean nullable = true;

    /**
     * Whether the column is included in SQL INSERT
     */
    private boolean insertAble = true;

    /**
     * Whether the column is included in SQL UPDATE
     */
    private boolean updatable = true;

    /**
     * The SQL fragment that is used when
     * generating the DDL for the column.
     */
    private String columnDefinition;

    /**
     * The name of the table that contains the column.
     * If absent the column is assumed to be in the primary table.
     */
    private String table;

    /**
     * (Optional) The column length. (Applies only if a
     * string-valued column is used.)
     */
    private int length =  255;

    private boolean id = false;

}

2. 反射工具类

提供一些常用的反射操作

通过反射我们可以动态的得到一个类所有的成员变量的信息,同时为这些变量取值或者赋值

package com.zyndev.tool.fastsql.util;


import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;


/**
 * The type Bean reflection util.
 *
 * @author yunan.zhang zyndev@gmail.com
 * @version 0.0.3
 * @date 2017年12月26日 16点36分
 */
public class BeanReflectionUtil {

    public static Object getPrivatePropertyValue(Object obj,String propertyName)throws Exception{
        Class cls=obj.getClass();
        Field field=cls.getDeclaredField(propertyName);
        field.setAccessible(true);
        Object retvalue=field.get(obj);
        return retvalue;
    }

    /**
     * Gets static field value.
     */
    public static Object getStaticFieldValue(String className, String fieldName) throws Exception {
        Class cls = Class.forName(className);
        Field field = cls.getField(fieldName);
        return field.get(cls);
    }

    /**
     * Gets field value.
     */
    public static Object getFieldValue(Object obj, String fieldName) throws Exception {
        Class cls = obj.getClass();
        Field field = cls.getDeclaredField(fieldName);
        field.setAccessible(true);
        return field.get(obj);
    }

    /**
     * Invoke method object.
     */
    public Object invokeMethod(Object owner, String methodName, Object[] args) throws Exception {
        Class cls = owner.getClass();
        Class[] argclass = new Class[args.length];
        for (int i = 0, j = argclass.length; i < j; i++) {
            argclass[i] = args[i].getClass();
        }
        @SuppressWarnings("unchecked")
        Method method = cls.getMethod(methodName, argclass);
        return method.invoke(owner, args);
    }

    /**
     * Invoke static method object.
     */
    public Object invokeStaticMethod(String className, String methodName, Object[] args) throws Exception {
        Class cls = Class.forName(className);
        Class[] argClass = new Class[args.length];
        for (int i = 0, j = argClass.length; i < j; i++) {
            argClass[i] = args[i].getClass();
        }
        @SuppressWarnings("unchecked")
        Method method = cls.getMethod(methodName, argClass);
        return method.invoke(null, args);
    }

    /**
     * New instance object.
     */
    public static Object newInstance(String className) throws Exception {
        Class clazz = Class.forName(className);
        @SuppressWarnings("unchecked")
        java.lang.reflect.Constructor cons = clazz.getConstructor();
        return cons.newInstance();
    }

    /**
     * New instance object.
     */
    public static Object newInstance(Class clazz) throws Exception {
        @SuppressWarnings("unchecked")
        java.lang.reflect.Constructor cons = clazz.getConstructor();
        return cons.newInstance();
    }

    /**
     * Get bean declared fields field [ ].
     */
    public static Field[] getBeanDeclaredFields(String className) throws ClassNotFoundException {
        Class clazz = Class.forName(className);
        return clazz.getDeclaredFields();
    }

    /**
     * Get bean declared methods method [ ].
     */
    public static Method[] getBeanDeclaredMethods(String className) throws ClassNotFoundException {
        Class clazz = Class.forName(className);
        return clazz.getMethods();
    }

    /**
     * Copy fields.
     */
    public static void copyFields(Object source, Object target) {
        try {
            List<Field> list = new ArrayList<>();
            Field[] sourceFields = getBeanDeclaredFields(source.getClass().getName());
            Field[] targetFields = getBeanDeclaredFields(target.getClass().getName());
            Map<String, Field> map = new HashMap<>(targetFields.length);
            for (Field field : targetFields) {
                map.put(field.getName(), field);
            }
            for (Field field : sourceFields) {
                if (map.get(field.getName()) != null) {
                    list.add(field);
                }
            }
            for (Field field : list) {
                Field tg = map.get(field.getName());
                tg.setAccessible(true);
                tg.set(target, getFieldValue(source, field.getName()));
            }
        } catch (Exception e) {
            e.printStackTrace();
        }
    }
}

3. 简单的 model 示例

package com.zyndev.tool.fastsql;

import javax.persistence.*;

/**
 * @author 张瑀楠 zyndev@gmail.com
 * @date   2017/11/30 下午11:21
 */
@Data
@Entity
@Table(name = "tb_student")
public class Student {

    @Id
    @Column
    private Integer id;

    @Column
    private String name;

    @Column(updatable = true, insertable = true, nullable = false)
    private Integer age;

}

4. 注解解析

将对象的上的注解进行解析,得到对应关系

package com.zyndev.tool.fastsql.core;


import com.zyndev.tool.fastsql.util.StringUtil;

import javax.persistence.Column;
import javax.persistence.Id;
import javax.persistence.Table;
import java.lang.reflect.Field;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;


/**
 * The type Annotation parser.
 * <p>
 * 注解解析工具
 *
 * @author yunan.zhang zyndev@gmail.com
 * @version 0.0.3
 * @since 2017-12-26 16:59:07
 */
public class AnnotationParser {

    /**
     * 存储类名和数据库表名的关系
     * 使用三个cache 主要为了减少反射的次数,提高效率
     */
    private final static Map<String, String> tableNameCache = new HashMap<>(30);

    /**
     * 存储类名和数据库列的关联关系
     */
    private final static Map<String, String> tableAllColumnNameCache = new HashMap<>(30);

    /**
     * 存储类名和对应的数据库全列名的关系
     */
    private final static Map<String, List<DBColumnInfo>> tableAllDBColumnCache = new HashMap<>(30);

    /**
     * Gets table name.
     * 获得表名
     * 判断是否有@Table注解,如果有得到注解的name,如果name为空,则使用类名做为表名
     * 如果没有@Table返回null
     *
     * @param <E>    the type parameter
     * @param entity the entity
     * @return the table name
     */
    public static <E> String getTableName(E entity) {
        String tableName = tableNameCache.get(entity.getClass().getName());
        if (tableName == null) {
            Table table = entity.getClass().getAnnotation(Table.class);
            if (table != null && StringUtil.isNotBlank(table.name())) {
                tableName = table.name();
            } else {
                tableName = entity.getClass().getSimpleName();
            }
            tableNameCache.put(entity.getClass().getName(), tableName);
        }
        return tableName;
    }

    /**
     * Gets all db column info.
     */
    public static <E> List<DBColumnInfo> getAllDBColumnInfo(E entity) {
        List<DBColumnInfo> dbColumnInfoList = tableAllDBColumnCache.get(entity.getClass().getName());
        if (dbColumnInfoList == null) {
            dbColumnInfoList = new ArrayList<>();
            DBColumnInfo dbColumnInfo;
            Field[] fields = entity.getClass().getDeclaredFields();
            for (Field field : fields) {
                Column column = field.getAnnotation(Column.class);
                if (column != null) {
                    dbColumnInfo = new DBColumnInfo();
                    if (StringUtil.isBlank(column.name())) {
                        dbColumnInfo.setColumnName(field.getName());
                    } else {
                        dbColumnInfo.setColumnName(column.name());
                    }
                    if (null != field.getAnnotation(Id.class)) {
                        dbColumnInfo.setId(true);
                    }
                    dbColumnInfo.setFieldName(field.getName());
                    dbColumnInfoList.add(dbColumnInfo);
                }
            }
            tableAllDBColumnCache.put(entity.getClass().getName(), dbColumnInfoList);
        }
        return dbColumnInfoList;
    }

    /**
     * 返回表字段的所有字段  column1,column2,column3
     *
     * @param <E>    the type parameter
     * @param entity the entity
     * @return string
     */
    public static <E> String getTableAllColumn(E entity) {

        String allColumn = tableAllColumnNameCache.get(entity.getClass().getName());
        if (allColumn == null) {
            List<DBColumnInfo> dbColumnInfoList = getAllDBColumnInfo(entity);
            StringBuilder allColumnsInfo = new StringBuilder();
            int i = 1;
            for (DBColumnInfo dbColumnInfo : dbColumnInfoList) {
                allColumnsInfo.append(dbColumnInfo.getColumnName());
                if (i != dbColumnInfoList.size()) {
                    allColumnsInfo.append(",");
                }
                i++;
            }
            allColumn = allColumnsInfo.toString();
            tableAllColumnNameCache.put(entity.getClass().getName(), allColumn);
        }
        return allColumn;

    }
}

5. 数据库操作

数据库交互使用spring 提供的 JdbcTemplate,这里没有在自己写一套 DBUtil

6. 结合反射实现查询操作

保存一个entity

保存操作相对简单,这里主要是将 entity 转换为 insert 语句

/**
 * Save int.
 * @param entity the entity
 * @return the int
 */
@Override
public int save(Object entity) {
    try {
        String tableName = AnnotationParser.getTableName(entity);
        StringBuilder property = new StringBuilder();
        StringBuilder value = new StringBuilder();
        List<Object> propertyValue = new ArrayList<>();
        List<DBColumnInfo> dbColumnInfoList = AnnotationParser.getAllDBColumnInfo(entity);

        for (DBColumnInfo dbColumnInfo : dbColumnInfoList) {
            if (dbColumnInfo.isId() || !dbColumnInfo.isInsertAble()) {
                continue;
            }
            // 不为null
            Object o = BeanReflectionUtil.getFieldValue(entity, dbColumnInfo.getFieldName());
            if (o != null) {
                property.append(",").append(dbColumnInfo.getColumnName());
                value.append(",").append("?");
                propertyValue.add(o);
            }
        }
        String sql = "insert into " + tableName + "(" + property.toString().substring(1) + ") values(" + value.toString().substring(1) + ")";
        return this.getJdbcTemplate().update(sql, propertyValue.toArray());
    } catch (Exception e) {
        e.printStackTrace();
    }
    return 0;
}

更新操作

更新操作相对于 保存来说,多了一步确定where 语句操作

/**
 * Update int.
 *
 * @param entity     the entity
 * @param ignoreNull the ignore null
 * @param columns    the columns
 * @return the int
 */
@Override
public int update(Object entity, boolean ignoreNull, String... columns) {
    try {
        String tableName = AnnotationParser.getTableName(entity);
        StringBuilder property = new StringBuilder();
        StringBuilder where = new StringBuilder();
        List<Object> propertyValue = new ArrayList<>();
        List<Object> wherePropertyValue = new ArrayList<>();
        List<DBColumnInfo> dbColumnInfos = AnnotationParser.getAllDBColumnInfo(entity);
        for (DBColumnInfo dbColumnInfo : dbColumnInfos) {

            Object o = BeanReflectionUtil.getFieldValue(entity, dbColumnInfo.getFieldName());
            if (dbColumnInfo.isId()) {
                where.append(" and ").append(dbColumnInfo.getColumnName()).append(" = ? ");
                wherePropertyValue.add(o);
            } else if (ignoreNull || o != null) {
                property.append(",").append(dbColumnInfo.getColumnName()).append("=?");
                propertyValue.add(o);
            }
        }

        if (wherePropertyValue.isEmpty()) {
            throw new IllegalArgumentException("更新表 [" + tableName + "] 无法找到id, 请求数据:" + entity);
        }

        String sql = "update " + tableName + " set " + property.toString().substring(1) + " where " + where.toString().substring(5);
        propertyValue.addAll(wherePropertyValue);
        return this.getJdbcTemplate().update(sql, propertyValue.toArray());
    } catch (Exception e) {
        e.printStackTrace();
    }
    return 0;
}

删除操作

相对 update 简单一点

/**
 * Delete int.
 * <p>根据id 删除对应的数据</p>
 *
 * @param entity the entity
 * @return the int
 */
@Override
public int delete(Object entity) {
    try {
        String tableName = AnnotationParser.getTableName(entity);
        StringBuilder where = new StringBuilder(" 1=1 ");
        List<Object> whereValue = new ArrayList<>(5);
        List<DBColumnInfo> dbColumnInfos = AnnotationParser.getAllDBColumnInfo(entity);
        for (DBColumnInfo dbColumnInfo : dbColumnInfos) {
            if (dbColumnInfo.isId()) {
                Object o = BeanReflectionUtil.getFieldValue(entity, dbColumnInfo.getFieldName());
                if (null != o) {
                    whereValue.add(o);
                }
                where.append(" and `").append(dbColumnInfo.getColumnName()).append("` = ? ");
            }
        }

        if (whereValue.size() == 0) {
            throw new IllegalStateException("delete " + tableName + " id 无对应值,不能删除");
        }
        String sql = "delete from  " + tableName + " where " + where.toString();
        return this.getJdbcTemplate().update(sql, whereValue);
    } catch (Exception e) {
        e.printStackTrace();
    }
    return 0;
}

通过上面的示例,就可以简单的实现一个 ORM, 为了更好的使用,我们还需要提供自己写 SQL 的方案

使用动态代理实现 @Query @Select 类似功能

1. 动态代理

这里直接使用基于 JDK 动态代理实现

2. 注解

Java Persistence API 中没有我们需要的 @Query 和 @Param 这里我们自定义一下这两个注解,同时为了让 Query 支持返回主键,在定义一个 ReturnGeneratedKey 注解

Query.java

package com.zyndev.tool.fastsql.annotation;

import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

/**
 * 查询操作 sql
 *
 * @author 张瑀楠 zyndev@gmail.com
 * @version 0.0.1
 * @since  2017/12/22 17:26
 */
@Target({ElementType.METHOD})
@Retention(RetentionPolicy.RUNTIME)
public @interface Query {

    /**
     * sql 语句
     */
    String value();
}

Param.java

package com.zyndev.tool.fastsql.annotation;

import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

/**
 *
 * @author 张瑀楠 zyndev@gmail.com
 * @version 0.0.1
 * @since  2017/12/22 17:29
 */
@Target( { ElementType.PARAMETER })
@Retention(RetentionPolicy.RUNTIME)
public @interface Param {

    /**
     * 指出这个值是 SQL 语句中哪个参数的值,使用命名参数
     */
    String value();
}

ReturnGeneratedKey.java

package com.zyndev.tool.fastsql.annotation;


import java.lang.annotation.*;


/**
 * 返回主键
 *
 * @author 张瑀楠 zyndev@gmail.com
 * @version 0.0.3
 *
 */
@Target({ElementType.METHOD})
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface ReturnGeneratedKey {
}

3. 表设计

CREATE TABLE `tb_user` (
  `id` int(11) NOT NULL AUTO_INCREMENT,
  `uid` varchar(40) DEFAULT NULL,
  `account_name` varchar(40) DEFAULT NULL,
  `nick_name` varchar(23) DEFAULT NULL,
  `password` varchar(30) DEFAULT NULL,
  `phone` varchar(16) DEFAULT NULL,
  `register_time` timestamp NULL DEFAULT NULL,
  `update_time` timestamp NULL DEFAULT NULL,
  PRIMARY KEY (`id`)
) ENGINE=InnoDB AUTO_INCREMENT=3 DEFAULT CHARSET=utf8

4. model

这里使用一个 User.java 作为例子:

User.java

package com.zyndev.tool.fastsql.repository;

import lombok.*;

import java.io.Serializable;
import java.util.Date;

/**
 * @author 张瑀楠 zyndev@gmail.com
 * @version 1.0
 * @date 2017-12-27 15:04:13
 */
@Data
public class User implements Serializable {

    private static final long serialVersionUID = 1L;

    private Integer id;

    private String uid;

    private String accountName;

    private String nickName;

    private String password;

    private String phone;

    private Date registerTime;

    private Date updateTime;

}

5. repository

package com.zyndev.tool.fastsql.repository;

import com.zyndev.tool.fastsql.annotation.Param;
import com.zyndev.tool.fastsql.annotation.Query;
import com.zyndev.tool.fastsql.annotation.ReturnGeneratedKey;
import org.springframework.stereotype.Repository;

import java.util.List;
import java.util.Map;

/**
 * 这里应该有描述
 *
 * @version 1.0
 * @author 张瑀楠 zyndev@gmail.com
 * @date 2017 /12/22 18:13
 */
@Repository
public interface UserRepository {

    @Query("select count(*) from tb_user")
    public Integer getCount();

    @Query("delete from tb_user where id = ?1")
    public Boolean deleteById(int id);

    @Query("select count(*) from tb_user where password = ?1 ")
    public int getCountByPassword(@Param("password") String password);

    @Query("select uid from tb_user where password = ?1 ")
    public String getUidByPassword(@Param("password") String password);

    @Query("select * from tb_user where id = :id ")
    public User getUserById(@Param("id") Integer id);

    @Query("select * " +
            " from tb_user " +
            " where account_name = :accountName ")
    public List<User> getUserByAccountName(@Param("accountName") String accountName);

    @Query("insert into tb_user(id, account_name, password, uid, nick_name, register_time, update_time) " +
            "values(:id, :user.accountName, :user.password, :user.uid, :user.nickName, :user.registerTime, :user.updateTime )")
    public int saveUser(@Param("id") Integer id, @Param("user") User user);

    @ReturnGeneratedKey
    @Query("insert into tb_user(account_name, password, uid, nick_name, register_time, update_time) " +
            "values(:user.accountName, :user.password, :user.uid, :user.nickName, :user.registerTime, :user.updateTime )")
    public int saveUser(@Param("user") User user);
}

7. 大体流程

在上面的我们已经完成一些准备工作,包括:

  1. 注解的定义
  2. 表的设计
  3. model 的设计
  4. Repository 的设计

接下来,我们看看如何将这些整合在一起

大致流程:

  1. 为 Repository 生成代理
  2. 将生成代理放入 Spring IOC 容器中
  3. 当代理的方法被调用时,得到方法的 @Query , @Param,@ReturnGeneratedKey 注解,并取得方法的返回值
  4. 重写 Query的sql,并执行,根据方法的返回类型,封装SQL返回结果集

8. 代理使用

FacadeProxy.java

为 Repository 生成代理,当代理方法执行时,回调 invoke 方法,invoke 中逻辑写到StatementParser.java中,防止类功能过大

package com.zyndev.tool.fastsql.core;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;

/**
 * 生成代理
 *
 * @author 张瑀楠 zyndev@gmail.com
 * @version 0.0.1
 * @since  2017 /12/23 上午12:40
 */
@SuppressWarnings("unchecked")
public class FacadeProxy implements InvocationHandler {

    private final static Log logger = LogFactory.getLog(FacadeProxy.class);

    @Override
    public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
        return StatementParser.invoke(proxy, method, args);
    }

    /**
     * New mapper proxy t.
     */
    protected static <T> T newMapperProxy(Class<T> mapperInterface) {
        logger.info(" 生成代理:" + mapperInterface.getName());
        ClassLoader classLoader = mapperInterface.getClassLoader();
        Class<?>[] interfaces = new Class[]{mapperInterface};
        FacadeProxy proxy = new FacadeProxy();
        return (T) Proxy.newProxyInstance(classLoader, interfaces, proxy);
    }
}

9. 将生成代理放入 Spring IOC 容器中

在这里使用了 BeanFactoryAware,关于这部分内容会单独写一篇,这里不在详细说明

package com.zyndev.tool.fastsql.core;

import com.zyndev.tool.fastsql.util.ClassScanner;
import com.zyndev.tool.fastsql.util.StringUtil;
import org.springframework.beans.BeansException;
import org.springframework.beans.factory.BeanFactory;
import org.springframework.beans.factory.BeanFactoryAware;
import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
import org.springframework.beans.factory.support.BeanDefinitionRegistry;
import org.springframework.context.annotation.ImportBeanDefinitionRegistrar;
import org.springframework.core.type.AnnotationMetadata;
import org.springframework.stereotype.Repository;

import java.io.IOException;
import java.util.Set;

/**
 * The type Fast sql repository registrar.
 *
 * @author 张瑀楠 zyndev@gmail.com
 * @version 1.0
 * @date 2018 /2/23 12:26
 */
public class FastSqlRepositoryRegistrar implements ImportBeanDefinitionRegistrar, BeanFactoryAware {

    private ConfigurableListableBeanFactory beanFactory;

    @Override
    public void registerBeanDefinitions(AnnotationMetadata annotationMetadata, BeanDefinitionRegistry beanDefinitionRegistry) {
        System.out.println("FastSqlRepositoryRegistrar registerBeanDefinitions ");
        String basePackage = "com.sparrow";
        ClassScanner classScanner = new ClassScanner();
        Set<Class<?>> classSet = null;
        try {
            classSet = classScanner.getPackageAllClasses(basePackage, true);
        } catch (IOException | ClassNotFoundException e) {
            e.printStackTrace();
        }
        for (Class clazz : classSet) {
            if (clazz.getAnnotation(Repository.class) != null) {
                beanFactory.registerSingleton(StringUtil.firstCharToLowerCase(clazz.getSimpleName()), FacadeProxy.newMapperProxy(clazz));
            }
        }
    }

    @Override
    public void setBeanFactory(BeanFactory beanFactory) throws BeansException {
        this.beanFactory = (ConfigurableListableBeanFactory) beanFactory;
    }
}

10. invoke方法处理

在前面生成动态的代理的时候,可以看到,所有的invoke逻辑由StatementParser.java处理,这也是一个重量级的方法

invoke执行流程说明:

invoke(Object proxy, Method method, Object[] args)

  1. 得到方法的返回类型
  2. 得到方法的**@Query**注解,取得需要执行的sql语句,无法取到sql则抛异常
  3. 获得方法的参数,并将参数顺序对应为 ?1->arg0 ?2->arg1 …
  4. 获得方法的参数和参数上@Param注解,并将参数与对应的Param的名称关联:param1->arg0 password->arg1
  5. 判断sql是select还是其他,使用正则 (?i)select([\\s\\S]*?)
  6. 重写sql
  7. 如果不是 select 语句,判断是否是 @ReturnGeneratedKey 注解
  8. 如果无 @ReturnGeneratedKey 则直接执行语句并返回对应的结果
  9. 如有有 @ReturnGeneratedKey 并且是 insert 语句则返回生成的主键
  10. 如果是 select 语句,则执行select 语句,并根据方法的返回类型封装结果集

关于重写sql

@Query("insert into tb_user(id, 
account_name, 
password, 
uid, 
nick_name, 
register_time, 
update_time) values(
    :id, 
    :user.accountName, 
    :user.password, 
    :user.uid, 
    :user.nickName, 
    :user.registerTime, 
    :user.updateTime )")
    public int saveUser(@Param("id") Integer id, @Param("user") User user); 

首先获取sql

insert into tb_user(id, account_name, password, uid, nick_name, register_time, update_time)
values(:id, :user.accountName, :user.password, :user.uid, :user.nickName, :user.registerTime, :user.updateTime )

可以看出这并不是标准的 sql 也不是 jdbc 可以识别的sql,这里我们使用正则**?\d+(.[A-Za-z]+)?|:[A-Za-z0-9]+(.[A-Za-z]+)?**

来提取出 ?1 ?2 :1 :2 :id :user.accountName 的特殊标志,并将其替换为 ?

替换过程中没替换一个 ? 则记录对应的 ?所代表的数值

替换后的SQL为

insert into tb_user(id, account_name, password, uid, nick_name, register_time, update_time)
values(?, ?, ?, ?, ?, ?, ? )

这样的sql就可以被 jdbc 处理了
同时参数允许为:

:id, :user.accountName, :user.password, :user.uid, :user.nickName, :user.registerTime, :user.updateTime

这里的 id 可以从参数中 id 直接获取,
:user.accountName 则需要从参数 :user 即 user 中通过反射获取,这样 SQL 的重写就完成了

返回结果集封装可以通过反射,可以直接看下面代码

StatementParser.java

package com.zyndev.tool.fastsql.core;

import com.sun.org.apache.bcel.internal.generic.IF_ACMPEQ;
import com.zyndev.tool.fastsql.annotation.Param;
import com.zyndev.tool.fastsql.annotation.Query;
import com.zyndev.tool.fastsql.annotation.ReturnGeneratedKey;
import com.zyndev.tool.fastsql.convert.BeanConvert;
import com.zyndev.tool.fastsql.convert.ListConvert;
import com.zyndev.tool.fastsql.convert.SetConvert;
import com.zyndev.tool.fastsql.util.BeanReflectionUtil;
import com.zyndev.tool.fastsql.util.StringUtil;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.jdbc.core.*;
import org.springframework.jdbc.support.GeneratedKeyHolder;
import org.springframework.jdbc.support.KeyHolder;
import org.springframework.jdbc.support.rowset.SqlRowSet;
import sun.reflect.generics.reflectiveObjects.NotImplementedException;

import java.lang.reflect.Method;
import java.lang.reflect.Parameter;
import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * sql 语句解析
 * <p>
 * 暂时只能处理 select count(*) from tb_user 类似语句
 *
 * @author 张瑀楠 zyndev@gmail.com
 * @version 0.0.1
 * @since 2017 /12/23 下午12:11
 */
class StatementParser {

    private final static Log logger = LogFactory.getLog(StatementParser.class);

    private static PreparedStatementCreator getPreparedStatementCreator(final String sql, final Object[] args, final boolean returnKeys) {
        PreparedStatementCreator creator = new PreparedStatementCreator() {

            @Override
            public PreparedStatement createPreparedStatement(Connection con) throws SQLException {
                PreparedStatement ps = con.prepareStatement(sql);
                if (returnKeys) {
                    ps = con.prepareStatement(sql, Statement.RETURN_GENERATED_KEYS);
                } else {
                    ps = con.prepareStatement(sql);
                }

                if (args != null) {
                    for (int i = 0; i < args.length; i++) {
                        Object arg = args[i];
                        if (arg instanceof SqlParameterValue) {
                            SqlParameterValue paramValue = (SqlParameterValue) arg;
                            StatementCreatorUtils.setParameterValue(ps, i + 1, paramValue,
                                    paramValue.getValue());
                        } else {
                            StatementCreatorUtils.setParameterValue(ps, i + 1,
                                    SqlTypeValue.TYPE_UNKNOWN, arg);
                        }
                    }
                }
                return ps;
            }
        };
        return creator;
    }

    /**
     * 此处对 Repository 中方法进行解析,解析成对应的sql 和 参数
     * <p>
     * sql 来自于 @Query 注解的 value
     * 参数 来自方法的参数
     * <p>
     * 注意根据返回值的不同封装结果集
     *
     * @param proxy  执行对象
     * @param method 执行方法
     * @param args   参数
     * @return object
     */
    static Object invoke(Object proxy, Method method, Object[] args) throws Exception {

        JdbcTemplate jdbcTemplate = DataSourceHolder.getInstance().getJdbcTemplate();

        boolean logDebug = logger.isDebugEnabled();

        String methodReturnType = method.getReturnType().getName();
        Query query = method.getAnnotation(Query.class);

        if (null == query || StringUtil.isBlank(query.value())) {
            logger.error(method.toGenericString() + " 无 query 注解或 SQL 为空");
            throw new IllegalStateException(method.toGenericString() + " 无 query 注解或 SQL 为空");
        }

        String originSql = query.value().trim();

        System.out.println("sql:" + query.value());
        Map<String, Object> namedParamMap = new HashMap<>();
        Parameter[] parameters = method.getParameters();
        if (args != null && args.length > 0) {
            for (int i = 0; i < args.length; ++i) {
                Param param = parameters[i].getAnnotation(Param.class);
                if (null != param) {
                    namedParamMap.put(param.value(), args[i]);
                }
                namedParamMap.put("?" + (i + 1), args[i]);
            }
        }

        if (logDebug) {
            logger.debug("执行 sql: " + originSql);
        }

        // 判断 sql 类型, 判断是否为 select 开头语句
        boolean isQuery = originSql.trim().matches("(?i)select([\\s\\S]*?)");
        Object[] params = null;
        // rewrite sql
        if (null != args && args.length > 0) {
            List<String> results = StringUtil.matches(originSql, "\\?\\d+(\\.[A-Za-z]+)?|:[A-Za-z0-9]+(\\.[A-Za-z]+)?");
            if (results.isEmpty()) {
                params = args;
            } else {
                params = new Object[results.size()];
                for (int i = 0; i < results.size(); ++i) {
                    if (results.get(i).charAt(0) == ':') {
                        originSql = originSql.replaceFirst(results.get(i), "?");
                        // 判断是否是 param.param 的格式
                        if (!results.get(i).contains(".")) {
                            params[i] = namedParamMap.get(results.get(i).substring(1));
                        } else {
                            String[] paramArgs = results.get(i).split("\\.");
                            Object param = namedParamMap.get(paramArgs[0].substring(1));
                            params[i] = BeanReflectionUtil.getFieldValue(param, paramArgs[1]);
                        }
                        continue;
                    }
                    int paramIndex = Integer.parseInt(results.get(i).substring(1));
                    originSql = originSql.replaceFirst("\\?" + paramIndex, "?");
                    params[i] = namedParamMap.get(results.get(i));
                }
            }
        }


        System.out.println("execute sql:" + originSql);
        System.out.print("params : ");
        if (null != params) {
            for (Object o : params) {
                System.out.print(o + ",\t");
            }
        }
        System.out.println("\n");


        /**
         * 如果返回值是基本类型或者其包装类
         */
        System.out.println(methodReturnType);
        if (isQuery) {
            // 查询方法
            if ("java.lang.Integer".equals(methodReturnType) || "int".equals(methodReturnType)) {
                return jdbcTemplate.queryForObject(originSql, params, Integer.class);
            } else if ("java.lang.String".equals(methodReturnType)) {
                return jdbcTemplate.queryForObject(originSql, params, String.class);
            } else if ("java.util.List".equals(methodReturnType) || "java.util.Set".equals(methodReturnType)) {
                String typeName = null;
                Type returnType = method.getGenericReturnType();
                if (returnType instanceof ParameterizedType) {
                    Type[] types = ((ParameterizedType) returnType).getActualTypeArguments();
                    if (null == types || types.length > 1) {
                        throw new IllegalArgumentException("当返回值为 list 时,必须标明具体类型,且只有一个");
                    }
                    typeName = types[0].getTypeName();
                }
                Object obj = BeanReflectionUtil.newInstance(typeName);
                SqlRowSet rowSet = jdbcTemplate.queryForRowSet(originSql, params);
                if ("java.util.List".equals(methodReturnType)) {
                    return ListConvert.convert(rowSet, obj);
                }
                return SetConvert.convert(rowSet, obj);
            } else if ("java.util.Map".equals(methodReturnType)) {
                throw new NotImplementedException();
            } else {
                SqlRowSet rowSet = jdbcTemplate.queryForRowSet(originSql, params);
                Object obj = BeanReflectionUtil.newInstance(methodReturnType);
                return BeanConvert.convert(rowSet, obj);
            }
        } else {
            // 非查询方法
            // 判断是否是insert 语句
            ReturnGeneratedKey returnGeneratedKeyAnnotation = method.getAnnotation(ReturnGeneratedKey.class);
            if (returnGeneratedKeyAnnotation == null) {
                int retVal = jdbcTemplate.update(originSql, params);
                if ("java.lang.Integer".equals(methodReturnType) || "int".equals(methodReturnType)) {
                    return retVal;
                } else if ("java.lang.Boolean".equals(methodReturnType)) {
                    return retVal > 0;
                }
            } else {
                // 判断是否是 insert 语句
                boolean isInsertSql = originSql.trim().matches("(?i)insert([\\s\\S]*?)");
                if (isInsertSql) {
                    KeyHolder keyHolder = new GeneratedKeyHolder();
                    PreparedStatementCreator preparedStatementCreator = getPreparedStatementCreator(originSql, params, true);
                    jdbcTemplate.update(preparedStatementCreator, keyHolder);
                    if ("java.lang.Integer".equals(methodReturnType) || "int".equals(methodReturnType)) {
                        return keyHolder.getKey().intValue();
                    } else if ("java.lang.Long".equals(methodReturnType) || "long".equals(methodReturnType)) {
                        return keyHolder.getKey().longValue();
                    }
                    logger.error(method.toGenericString() + " 返回主键id应该为 int 或者 long 类型 ");
                    throw new IllegalArgumentException(method.toGenericString() + " 返回主键id应该为 int 或者 long 类型 ");
                } else {
                    logger.error(method.toGenericString() + " 非 insert 语句 无法返回 GeneratedKey: sql语句为:" + originSql);
                    throw new IllegalStateException(method.toGenericString() + " 非 insert 语句 无法返回 GeneratedKey: sql语句为:" + originSql);
                }
            }
        }
        return null;
    }
}

由此一个简单 ORM 就实现了,其实实现 ORM 并不难,难的是细心处理各种可能的 Bug

在这里插入图片描述

  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 3
    评论
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值