最近在公司实习过程中,TL提出一个需求,要求在不使用Hibernate的情况下实现一个比较通用的DAO框架,使用JdbcTemplate作为SQL语句的执行工具。在参考了CloudStack 3.0.2的相关源代码后,我自己实现了一个简化版的DAO框架。结果后来TL又说改用Python开发,只好遗憾地把这些代码留作纪念。
简单的类图参见链接:http://pan.baidu.com/share/link?shareid=115118&uk=3592520259
环境为MyEclipse8.5+Spring2.5,使用jar为asm-3.3.1、cglib-2.2、mysql-connector
1,编程思想
本质上是将某些通用的API(如最基础的CRUD)直接通过泛型类来实现。唯一比较难处理的是Update时如何确定哪些属性需要更新:这可以通过CGLIB类库拦截setter方法,并记录被改变的属性来解决。
2,代码
核心类BaseDaoImpl
package DataBaseDemo.daoimpl;
import java.lang.reflect.Field;
import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import javax.sql.DataSource;
import net.sf.cglib.proxy.Enhancer;
import org.apache.log4j.Logger;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.jdbc.core.PreparedStatementCreator;
import org.springframework.jdbc.core.PreparedStatementSetter;
import org.springframework.jdbc.datasource.DataSourceTransactionManager;
import org.springframework.jdbc.support.GeneratedKeyHolder;
import org.springframework.jdbc.support.KeyHolder;
import org.springframework.transaction.PlatformTransactionManager;
import org.springframework.transaction.TransactionDefinition;
import org.springframework.transaction.support.DefaultTransactionDefinition;
import DataBaseDemo.interceptor.UpdateFactory;
import DataBaseDemo.util.DBUtils;
import DataBaseDemo.util.ModelRowMapper;
import com.mysql.jdbc.Statement;
public class BaseDaoImpl<T> {
Logger logger=Logger.getLogger(BaseDaoImpl.class);
//POJO类的实际类型
Class<T> entityType;
//简单地将POJO类名映射成数据库表名
String table;
public static JdbcTemplate jdbcTemplate;
public static PlatformTransactionManager transactionManager;
public static DefaultTransactionDefinition transactionDef;
@SuppressWarnings("unchecked")
BaseDaoImpl() {
DataSource datasource=DBUtils.configureDatasource();
jdbcTemplate = new JdbcTemplate(datasource);
transactionManager=new DataSourceTransactionManager(datasource);
transactionDef=new DefaultTransactionDefinition(TransactionDefinition.PROPAGATION_REQUIRED);
Type t = getClass().getGenericSuperclass();
// 使用该语法后,BaseDaoImpl无法正常使用,只能通过子类调用它
// 用于获取实际输入的Model类型
if (t instanceof ParameterizedType) {
entityType = (Class<T>) ((ParameterizedType) t)
.getActualTypeArguments()[0];
} else if (((Class<?>) t).getGenericSuperclass() instanceof ParameterizedType) {
entityType = (Class<T>) ((ParameterizedType) ((Class<?>) t)
.getGenericSuperclass()).getActualTypeArguments()[0];
} else {
entityType = (Class<T>) ((ParameterizedType) ((Class<?>) ((Class<?>) t)
.getGenericSuperclass()).getGenericSuperclass())
.getActualTypeArguments()[0];
}
this.table = DBUtils.getTable(entityType);
}
@SuppressWarnings("unchecked")
public List<T> queryAll() {
String sql = "select * from " + table;
List<T> list = jdbcTemplate.query(sql, new ModelRowMapper(entityType));
return list;
}
/**
* 根据ID,查询一条记录并用实体包装
*
* @param id
* @return
*/
@SuppressWarnings("unchecked")
public T load(int id) {
String sql = "select * from " + table + " where id=" + id;
List<T> list = jdbcTemplate.query(sql, new ModelRowMapper(entityType));
return list.size() == 0 ? null : list.get(0);
}
/**
* 根据ID,删除指定表里的记录
*
* @param id
*/
public void delete(int id) {
String sql = "delete from " + table + " where id=" + id;
jdbcTemplate.execute(sql);
}
/**
* 根据ID更新实体数据到数据库
*
* @param entity
* @param id
*/
@SuppressWarnings("unchecked")
public void update(T entity, int id) {
assert Enhancer.isEnhanced(entity.getClass()) : "没有被拦截器监控到更新数据";
StringBuilder sql = new StringBuilder();
sql.append("update " + table + " set ");
System.out.println(entity.hashCode());
HashMap<String, Object> map = UpdateFactory.getChanges(entity.hashCode());
List<String> keys = new ArrayList<String>();
List<Object> values = new ArrayList<Object>();
Iterator iter = map.entrySet().iterator();
while (iter.hasNext()) {
Map.Entry entry = (Map.Entry) iter.next();
String key = (String) entry.getKey();
Object val = entry.getValue();
keys.add(key);
values.add(val);
}
for (int i = 0; i < keys.size(); i++) {
if (i == keys.size() - 1) {
sql.append(keys.get(i) + "=? ");
} else {
sql.append(keys.get(i) + "=?,");
}
}
sql.append("where id=?");
logger.info("更新语句:"+sql.toString());
values.add(id);
jdbcTemplate.update(sql.toString(), setParams(values.toArray()));
}
/**
* 插入实体,并返回数据库自增生成的ID
*
* @param entity
* @return
*/
@SuppressWarnings("unchecked")
public int insert(T entity) {
final StringBuilder sql = new StringBuilder();
sql.append("insert into " + table + "(");
HashMap<String, Object> map = getChangesForInsert(entity);
List<String> columns = new ArrayList<String>();
final List<Object> values = new ArrayList<Object>();
Iterator iter = map.entrySet().iterator();
while (iter.hasNext()) {
Map.Entry entry = (Map.Entry) iter.next();
String key = (String) entry.getKey();
Object val = entry.getValue();
columns.add(key);
values.add(val);
}
for (int i = 0; i < columns.size(); i++) {
if (i == columns.size() - 1) {
sql.append(columns.get(i) + ") values(");
} else {
sql.append(columns.get(i) + ",");
}
}
for (int i = 0; i < values.size(); i++) {
if (i == values.size() - 1) {
sql.append("?)");
} else {
sql.append("?,");
}
}
logger.info("插入语句:"+sql.toString());
KeyHolder key = new GeneratedKeyHolder();
final String insertSql=sql.toString();
jdbcTemplate.update(new PreparedStatementCreator() {
@Override
public PreparedStatement createPreparedStatement(Connection con)
throws SQLException {
// TODO Auto-generated method stub
//必须设置Statement.RETURN_GENERATED_KEYS才能进行返回ID
PreparedStatement ps = jdbcTemplate.getDataSource()
.getConnection().prepareStatement(insertSql,Statement.RETURN_GENERATED_KEYS);
for (int i = 0; i < values.size(); i++) {
ps.setObject(i + 1, values.get(i));
}
return ps;
}
}, key);
return key.getKey().intValue();
}
/**
* 插入实体并返回被插入的实体
*
* @param entity
* @return
*/
public T persist(final T entity) {
int id=insert(entity);
//直接通过connection进行提交同样无法成功
// transaction.commit();
logger.info("数据库返回的自增ID为:"+id);
T persisted=load(id);
return persisted;
}
/**
*
* @param params
* @return
*/
@SuppressWarnings("unchecked")
public List<T> query(SearchCriteria sc){
String where = sc.generateWhereClause();
StringBuilder sb = new StringBuilder("select * from "+sc.getTable());
sb.append(where);
logger.info("查询语句"+sb.toString());
logger.info("查询参数"+Arrays.toString(sc.generateParams()));
List<T> list = jdbcTemplate.query(sb.toString(), setParams(sc.generateParams()),new ModelRowMapper(entityType));
return list;
}
/**
* 返回DAO对应的数据库表的总记录数
* @return
*/
public int getTotalCount(){
String sql="select count(*) from "+table;
return jdbcTemplate.queryForInt(sql);
}
/**
* 以pagesize大小的页,返回第page页的数据
* @param page
* @param pagesize
* @return
*/
@SuppressWarnings("unchecked")
public List<T> getPage(int page,int pagesize){
if(page<0||pagesize<0){
throw new IllegalArgumentException("页码或页大小参数不合法");
}
String sql="select * from "+table+" limit "+page*pagesize+","+(page+1)*pagesize;
return jdbcTemplate.query(sql, new ModelRowMapper<T>(entityType));
}
/**
* 直接执行sql查询语句,param作为参数数组
* @param sql
* @param params
* 返回查询到的结果列表
* @return
*/
@SuppressWarnings("unchecked")
public List<T> executeRawSql(String sql,Object[] params){
return jdbcTemplate.query(sql, setParams(params)
, new ModelRowMapper<T>(entityType));
}
/**
* 设置查询用的参数列表
* @param params
* @return
*/
protected PreparedStatementSetter setParams(final Object[] params) {
return new PreparedStatementSetter() {
@Override
public void setValues(PreparedStatement ps) throws SQLException {
// TODO Auto-generated method stub
for (int i = 0; i < params.length; i++) {
ps.setObject(i + 1, params[i]);
}
}
};
}
/**
* 返回待插入实体上的所有非空属性值及属性名的Map
* @param entity
* @return
*/
protected HashMap<String, Object> getChangesForInsert(T entity){
Field[] fields = entityType.getDeclaredFields();
HashMap<String, Object> insertValues = new HashMap<String, Object>();
try {
for (Field field : fields) {
field.setAccessible(true);
//跳过id字段
if("id".equalsIgnoreCase(field.getName()))
continue;
Object value = field.get(entity);
if (value == null)
continue;
insertValues.put(field.getName(), value);
}
return insertValues;
} catch (IllegalArgumentException e) {
// TODO Auto-generated catch block
e.printStackTrace();
} catch (IllegalAccessException e) {
// TODO Auto-generated catch block
e.printStackTrace();
}
return null;
}
}
子类Dao实例:UserDaoImpl
package DataBaseDemo.daoimpl;
import java.util.List;
import DataBaseDemo.dao.UserDao;
import DataBaseDemo.model.UserVO;
/**
 * DAO for {@link UserVO}; all generic CRUD comes from {@code BaseDaoImpl<UserVO>}.
 */
public class UserDaoImpl extends BaseDaoImpl<UserVO> implements UserDao {

    /**
     * Placeholder implementation: hands back a new, unpopulated UserVO without touching the
     * database.
     *
     * @return a fresh UserVO instance
     */
    @Override
    public UserVO queryUser() {
        return new UserVO();
    }

    /**
     * Custom query wrapper: lists every user row in the table.
     *
     * @return all users
     */
    public List<UserVO> listUsers() {
        List<UserVO> allUsers = queryAll();
        return allUsers;
    }
}
核心拦截工厂UpdateFactory:通过CGLIB的接口拦截Model的setter方法,并记录被修改的属性
package DataBaseDemo.interceptor;
import java.util.HashMap;
import net.sf.cglib.proxy.Callback;
import net.sf.cglib.proxy.Enhancer;
import net.sf.cglib.proxy.NoOp;
/**
 * Records the properties changed on model objects and creates CGLIB-enhanced model instances.
 *
 * <p>Changes are keyed by the object's hashCode; each value is a map from property name to the
 * new value. NOTE(review): because the key is a hashCode, two distinct objects with equal hash
 * codes share one entry; the map is also not synchronized.
 *
 * @author Administrator
 */
public class UpdateFactory {
    /** hashCode -> (property name -> new value) for every tracked object. */
    public static HashMap<Integer, HashMap<String, Object>> changes =
            new HashMap<Integer, HashMap<String, Object>>();

    /**
     * Records that property {@code key} of the object with hash {@code hash} was set to
     * {@code value}, overwriting any earlier value for the same property.
     *
     * @param hash the object's hashCode
     * @param key property name
     * @param value new property value
     */
    public static void addChange(Integer hash, String key, Object value) {
        HashMap<String, Object> pending = changes.get(hash);
        if (pending == null) {
            pending = new HashMap<String, Object>();
            changes.put(hash, pending);
        }
        pending.put(key, value);
    }

    /**
     * Returns all recorded changes for the object with hash {@code hash}.
     *
     * @param hash the object's hashCode
     * @return property name -> new value, or {@code null} if nothing was recorded
     */
    public static HashMap<String, Object> getChanges(Integer hash) {
        return changes.get(hash);
    }

    /**
     * Creates an instance of {@code clazz} enhanced by CGLIB with callbacks
     * {@code [NoOp.INSTANCE, UpdateInterceptor]}, selected per method by {@code SetFilter}
     * (presumably routing setters to the interceptor — confirm in SetFilter).
     *
     * @param clazz the model class to subclass
     * @return the enhanced instance
     */
    public static Object createVO(Class<?> clazz) {
        // A fresh Enhancer per call: a shared static instance could be reconfigured by a
        // concurrent caller between setSuperclass() and create().
        Enhancer enhancer = new Enhancer();
        enhancer.setSuperclass(clazz);
        enhancer.setCallbacks(new Callback[] { NoOp.INSTANCE, new UpdateInterceptor() });
        enhancer.setCallbackFilter(new SetFilter());
        return enhancer.create();
    }
}
package DataBaseDemo.util;
import java.lang.reflect.Field;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.HashMap;
import org.springframework.jdbc.core.RowMapper;
import DataBaseDemo.interceptor.UpdateFactory;
/**
 * Spring RowMapper that maps each result row onto a new model instance.
 *
 * <p>Instances are created through {@code UpdateFactory.createVO} so that later setter calls can
 * be tracked. Fields are written directly via reflection, so that populating the entity does not
 * go through the setters.
 *
 * @param <T> model type produced by this mapper
 * @author Administrator
 */
public class ModelRowMapper<T> implements RowMapper {
    /** Model class whose declared fields are filled from same-named result-set columns. */
    Class<?> clazz;

    public ModelRowMapper(Class<?> clazz) {
        this.clazz = clazz;
    }

    /**
     * Copies each non-null entry of {@code map} into the same-named instance field of
     * {@code entity}. Fields are written directly (not via setters) so that a set-interceptor
     * is not triggered.
     *
     * <p>The whole class hierarchy of {@code entity} is searched. For a CGLIB proxy,
     * {@code getClass().getDeclaredFields()} returns only the proxy's own fields, so the model's
     * fields live on a superclass.
     *
     * @param map field name -> value; {@code null} copies nothing
     * @param entity target object
     * @return {@code entity}
     * @throws IllegalArgumentException if a value's type does not fit its field
     */
    public static Object setValues(HashMap<String, Object> map, Object entity) {
        if (map == null || entity == null) {
            return entity;
        }
        for (Class<?> c = entity.getClass(); c != null && c != Object.class; c = c.getSuperclass()) {
            for (Field field : c.getDeclaredFields()) {
                if (!isMappable(field)) {
                    continue;
                }
                Object value = map.get(field.getName());
                if (value != null) {
                    assign(field, entity, value);
                }
            }
        }
        return entity;
    }

    /**
     * Fills the fields declared on the model class from the current row of {@code rs}.
     *
     * @param rs result set positioned on a row
     * @param entity target instance (the model class or a subclass of it)
     * @throws IllegalStateException wrapping any SQLException raised while reading the row
     */
    public void setValues(ResultSet rs, Object entity) {
        try {
            populate(rs, entity);
        } catch (SQLException e) {
            throw new IllegalStateException("Failed to map row onto " + clazz.getName(), e);
        }
    }

    /**
     * Creates a change-tracked instance via {@code UpdateFactory.createVO} and fills it from the
     * current row. SQLExceptions propagate to JdbcTemplate for translation.
     */
    @SuppressWarnings("unchecked")
    @Override
    public T mapRow(ResultSet rs, int rowNum) throws SQLException {
        T entity = (T) UpdateFactory.createVO(clazz);
        populate(rs, entity);
        return entity;
    }

    /**
     * Sets every mappable field of {@code clazz} from the same-named column of {@code rs}.
     *
     * <p>Fields with no matching column are skipped rather than aborting the whole row.
     * Column names are compared case-insensitively, as JDBC getters do. A SQL NULL is not
     * assigned to a primitive field, which keeps its default value.
     */
    private void populate(ResultSet rs, Object entity) throws SQLException {
        java.sql.ResultSetMetaData meta = rs.getMetaData();
        java.util.Set<String> columns =
                new java.util.TreeSet<String>(String.CASE_INSENSITIVE_ORDER);
        for (int i = 1; i <= meta.getColumnCount(); i++) {
            columns.add(meta.getColumnLabel(i));
        }
        for (Field field : clazz.getDeclaredFields()) {
            if (!isMappable(field) || !columns.contains(field.getName())) {
                continue;
            }
            Object value = rs.getObject(field.getName());
            if (value == null && field.getType().isPrimitive()) {
                continue;
            }
            assign(field, entity, value);
        }
    }

    /** True for instance fields that can correspond to a column (not static, not synthetic). */
    private static boolean isMappable(Field field) {
        return !java.lang.reflect.Modifier.isStatic(field.getModifiers()) && !field.isSynthetic();
    }

    /** Writes {@code value} into {@code field} of {@code target}, reporting failures with context. */
    private static void assign(Field field, Object target, Object value) {
        field.setAccessible(true);
        try {
            field.set(target, value);
        } catch (IllegalAccessException e) {
            throw new IllegalStateException("Cannot write field " + field.getName(), e);
        } catch (IllegalArgumentException e) {
            String valueType = value == null ? "null" : value.getClass().getName();
            throw new IllegalArgumentException("Value of type " + valueType
                    + " does not fit field " + field.getName()
                    + " (" + field.getType().getName() + ")", e);
        }
    }
}
3,缺点
--目前的查询非常简单,需要进行优化
--无法支持事务管理,原因不明,进一步研究中
所有源码参照CloudStack3.0.2的相关代码编写