关闭

[置顶] BaseDao——基于dbutils实现万能的对象数据库访问

672人阅读 评论(0) 收藏 举报
分类:
BaseDao——基于dbutils实现万能对象的数据库访问

前言:dbutils的好用之处不必多说,但是在使用的过程中发现要对每个Bean写一个BeanDao,来实现Bean的访问,虽然一个Bean只需要写一次,但是对于复杂的Bean来说,其属性众多,写起来也是非常费时间的。所有我想写一个万能的BaseDao来一劳永逸。
1、BaseDao之前的BeanDao的添加实现:
	public boolean addTestBean(TestBean testBean){
		QueryRunner qr = new TxQueryRunner();
		String sql = "insert into group values(?,?,?,?,?,?,?)";
		Object params[] = {testBean.getTitle(),testBean.getCount(),testBean.getClick(),
				testBean.getLen(),testBean.getId(),testBean.getName(),testBean.getChild()};
		try {
			qr.update(sql, params);
			return true;
		} catch (SQLException e) {
			e.printStackTrace();
			return false;
		}
	}
2、BaseDao的添加实现是万能的,所有使用如下:
BaseDao bd = new BaseDao<TestBean>("test_bean", TestBean.class, new String[]{"id"});
TestBean bean = new TestBean();
bean.setId("123");
bean.setName("test");
bd.addObject(bean);
可见BaseDao还是非常好用的,只需实例化的时传入3个参数,然后就随意的增删改查了。第1个参数是表名,如果传入null,则会使用Bean的名称作表名(TestBean ==> test_bean),第2个参数为Bean的class,第3个参数为主键,可以为多键联合主键。
3、这里除了BaseDao的源码外还有一个工具类:MBUtils包含了BaseDao中用到的反射工具,都是刚开始写的,所以功能不是很多,但是还有个很好用的功能是可以根据Bean生成创建表的sql语句,方便Bean属性过多时使用非常方便,但是非常简易,以后再优化。
使用非常简单,一行代码:
MBUtils.generateCreateTable(null, TestBean.class,new String[]{"id"})
其中第1个参数是表名,如果传入null,则会使用Bean的名称作表名(TestBean ==> test_bean)。
效果如下:
CREATE TABLE 'test_bean'(
	'name' varchar(255),
	'click' varchar(255),
	'id' varchar(255),
	'count' int(32),
	'title' varchar(255),
	'child' varchar(255),
	'len' varchar(255),
	 PRIMARY KEY ('id')
) ENGINE=InnoDB DEFAULT CHARSET=utf8 COLLATE=utf8_bin;
4、依赖的jar包
c3p0-0.9.2-pre1.jar
commons-dbutils-1.4.jar
commons-logging-1.1.1.jar
itcast-tools-1.4.2.jar
mchange-commons-0.2.jar
mysql-connector-java-5.1.28-bin.jar
下载地址:
5、关于c3p0的使用,需要将c3p0-config.xml文件放到src目录下,文件名不可更改。然后配置一下内容:
<?xml version="1.0" encoding="UTF-8"?>
<c3p0-config>
	<!-- 这是默认配置信息 -->
	<default-config> 
		<!-- 连接四大参数配置 -->
		<property name="jdbcUrl">jdbc:mysql://localhost:3306/dbname</property>
		<property name="driverClass">com.mysql.jdbc.Driver</property>
		<property name="user">root</property>
		<property name="password">123456</property>
		<!-- 池参数配置 -->
		<property name="acquireIncrement">3</property>
		<property name="initialPoolSize">10</property>
		<property name="minPoolSize">2</property>
		<property name="maxPoolSize">10</property>
	</default-config>
</c3p0-config>
6、为了BaseDao的独立使用,BaseDao包含了MBUtils的部分函数,两个类的代码如下:
BaseDao.java
package com.match.sqlmodel;

import java.lang.reflect.Field;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import org.apache.commons.dbutils.QueryRunner;
import org.apache.commons.dbutils.handlers.BeanHandler;
import org.apache.commons.dbutils.handlers.BeanListHandler;

import cn.itcast.jdbc.TxQueryRunner;
/**
 * 基于dbutils的万能的BaseDao
 * @author 亓根火柴
 * @date 2017-1-26
 * @param <T>
 */
public class BaseDao<T> {
	
	private QueryRunner qr = new TxQueryRunner();
	/**
	 * 表名
	 */
	private String table;
	/**
	 * 对象类型
	 */
	private Class cls;
	/**
	 * 主键(联合主键)
	 */
	private String[] primaryKeys;
	/**
	 * 隐藏默认构造器,强制初始化各参数
	 */
	private BaseDao(){};
	
	public BaseDao(String tableName,Class cls,String[] primaryKeys){
		if((tableName==null)||tableName.equals("")){
			table = getSqlName(cls.getSimpleName());
		}else{
			table = tableName;
		}
		this.cls = cls;
		this.primaryKeys = primaryKeys;
	}
	/**
	 * 添加对象到数据库中
	 * @param obj 对象
	 * @return 成功返回true
	 */
	public boolean addObject(Object obj){
		Map<String,Object> result = generateInsertParams(table,obj,cls);
		String sql = (String) result.get("sql");
		Object params[] = (Object[]) result.get("params");
		printLog(sql, params);
		try {
			qr.update(sql, params);
			return true;
		} catch (SQLException e) {
			e.printStackTrace();
			return false;
		}
	}
	/**
	 * 根据主键从数据库中删除该对象
	 * @param obj 对象
	 * @return 删除成功返回true
	 */
	public boolean deleteObject(Object obj){
		Map<String,Object> result = generateDeleteParams(table,obj,cls);
		String sql = (String) result.get("sql");
		Object params[] = (Object[]) result.get("params");
		printLog(sql, params);
		try {
			qr.update(sql,params);
			return true;
		} catch (SQLException e) {
			e.printStackTrace();
			return false;
		}
	}
	/**
	 * 根据主键修改数据库中的对象
	 * @param obj 修改对象
	 * @return 修改成功返回true
	 */
	public boolean editObject(Object obj){
		Map<String,Object> result = generateEditParams(table,obj,cls);
		String sql = (String) result.get("sql");
		Object params[] = (Object[]) result.get("params");
		printLog(sql, params);
		try {
			qr.update(sql, params);
			return true;
		} catch (SQLException e) {
			e.printStackTrace();
			return false;
		}
	}
	/**
	 * 根据主键查询对象
	 * @param obj 对象
	 * @return 查询出的对象
	 */
	public T queryObject(Object obj) {
		Map<String,Object> result = generateQueryParams(table,obj,cls);
		String sql = (String) result.get("sql");
		Object params[] = (Object[]) result.get("params");
		printLog(sql, params);
		try {
			return qr.query(sql, new BeanHandler<T>(cls), params);
		} catch (SQLException e) {
			e.printStackTrace();
			return null;
		}
	}
	/**
	 * 查询所有对象
	 * @return 对象列表
	 */
	public List<T> findAll(){
		String sql = "select * from "+table;
		try {
			return qr.query(sql, new BeanListHandler<T>(cls));
		} catch (SQLException e) {
			e.printStackTrace();
		}
		return null;
	}
	/**
	 * 生成插入语句和参数
	 * @param table 表名
	 * @param obj 对象
	 * @param cls 类型
	 * @return 语句和参数
	 */
	private Map<String, Object> generateInsertParams(String table, Object obj, Class cls) {
		Map<String, Object> data = new HashMap<String, Object>();
		//生成sql语句
		String[] fields = getAllFields(cls);
		StringBuffer sql = new StringBuffer("insert into "+table+"(");
		StringBuffer val = new StringBuffer("values(");
		for(String temp : fields){
			sql.append(temp+",");
			val.append("?,");
		}
		sql.deleteCharAt(sql.length()-1);
		val.deleteCharAt(val.length()-1);
		sql.append(") ");
		val.append(")");
		data.put("sql", sql.append(val).toString());
		//生成参数
		Object values[] = new Object[fields.length];
		for(int i = 0;i < fields.length;i++){
			values[i] = getValues(fields[i],obj,cls);
		}
		data.put("params", values);
		return data ;
	}
	/**
	 * 生成删除语句和参数
	 * @param table 表名
	 * @param obj 对象
	 * @param cls 类型
	 * @return 语句和参数
	 */
	private Map<String, Object> generateDeleteParams(String table, Object obj,Class cls) {
		Map<String, Object> data = new HashMap<String, Object>();
		//生成sql语句和参数
		StringBuffer sql = new StringBuffer("delete from "+table+" where ");
		Object values[] = new Object[primaryKeys.length];
		for(int i = 0;i < primaryKeys.length;i++){
			sql.append(primaryKeys[i] + "=? and ");
			values[i] = getValues(primaryKeys[i],obj,cls);
		}
		sql = sql.delete(sql.length()-4, sql.length());
		data.put("sql", sql.toString());
		data.put("params", values);
		return data;
	}
	/**
	 * 生成修改语句和参数
	 * @param table 表名
	 * @param obj 对象
	 * @param cls 类型
	 * @return 语句和参数
	 */
	private Map<String, Object> generateEditParams(String table, Object obj,Class cls) {
		Map<String, Object> data = new HashMap<String, Object>();
		//生成sql语句
		String[] fields = getAllFields(cls);
		List<String> pks = Arrays.asList(primaryKeys);
		StringBuffer sql = new StringBuffer("update "+table+" set ");
		for(String temp : fields){
			if(!pks.contains(temp)){
				sql.append(temp+"=?,");
			}
		}
		sql.deleteCharAt(sql.length()-1);
		sql.append(" where ");
		for(String temp : primaryKeys){
			sql.append(temp+"=? and ");
		}
		sql = sql.delete(sql.length()-4, sql.length());
		data.put("sql", sql.toString());
		//生成参数
		Object values[] = new Object[fields.length];
		int j = 0;
		for(int i = 0;i < fields.length;i++){
			if(!pks.contains(fields[i])){
				values[j] = getValues(fields[i],obj,cls);
				j++;
			}
		}
		for(String temp : primaryKeys){
			values[j] = getValues(temp,obj,cls);
			j++;
		}
		data.put("params", values);
		return data ;
	}
	/**
	 * 生成查询语句和参数
	 * @param table 表名
	 * @param obj 对象
	 * @param cls 类型
	 * @return 语句和参数
	 */
	private Map<String, Object> generateQueryParams(String table, Object obj,Class cls) {
		Map<String, Object> data = new HashMap<String, Object>();
		//生成sql语句和参数
		StringBuffer sql = new StringBuffer("select * from "+table+" where ");
		Object values[] = new Object[primaryKeys.length];
		for(int i = 0;i < primaryKeys.length;i++){
			sql.append(primaryKeys[i] + "=? and ");
			values[i] = getValues(primaryKeys[i],obj,cls);
		}
		sql = sql.delete(sql.length()-4, sql.length());
		data.put("sql", sql.toString());
		data.put("params", values);
		return data;
	}
	/**
	 * 根据set方法和Class.getFields(),获取该类的所有域,包括公共域和私有域
	 * @param cls 类型
	 * @return
	 */
	private String[] getAllFields(Class cls) {
		Field[] pubFields = cls.getFields();
		List<String> setFields = new ArrayList<String>();
		Method[] methods = cls.getMethods();
		for(Method method : methods){
			if(method.getName().startsWith("set")){
				String field = method.getName().substring(3).toLowerCase();
				setFields.add(field);
			}
		}
		for(Field temp:pubFields){
			if(!setFields.contains(temp.getName())){
				setFields.add(temp.getName());
			}
		}
		return setFields.toArray(new String[setFields.size()]);
	}
	/**
	 * 根据域名获取该对象域的值
	 * @param field 域名
	 * @param obj 对象
	 * @param cls 类型
	 * @return
	 */
	private Object getValues(String field, Object obj, Class cls) {
		Method[] methods = cls.getMethods();
		String firstChar = field.charAt(0)+"";
		firstChar = firstChar.toUpperCase();
		String methodName = "get"+firstChar+field.substring(1);
		for(Method method : methods){
			if(method.getName().equals(methodName)){
				try {
					return method.invoke(obj, null);
				} catch (IllegalAccessException e) {
					e.printStackTrace();
				} catch (IllegalArgumentException e) {
					e.printStackTrace();
				} catch (InvocationTargetException e) {
					e.printStackTrace();
				}
			}
		}
		return null;
	}

	private String getSqlName(String name){
		if((name == null)||(name.equals("")))return null;
		name = (name.charAt(0)+"").toLowerCase()+name.substring(1);
		int index = 0;
		while((index = contains(name))!=-1){
			char c = name.charAt(index);
			name = name.replaceFirst(c+"", ("_"+c).toLowerCase());
		}
		return name;
	}
	private int contains(String str){
		for(int i=0;i<str.length();i++){
			char c = str.charAt(i);
			if(('A'<=c)&&(c<='Z')){
				return i;
			}
		}return -1;
	}
	private void printLog(String sql, Object[] params) {
//		System.out.print("sql = "+sql+"\nparams = ");
//		for(int i=0;i<params.length-1;i++){
//			System.out.print(params[i]+",");
//		}System.out.println(params[params.length-1]);
	}
}
MBUtils.java
package com.match.sqlmodel;

import java.lang.reflect.Field;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
/**
 * MySql & Bean Utils
 * @author 亓根火柴
 * @date 2017-1-26
 */
public class MBUtils {
	/**
	 * 简易地生成数据表创建语句
	 * @param table
	 * @param cls
	 * @param primaryKeys
	 * @return
	 */
	public static String generateCreateTable(String table,Class cls,String[] primaryKeys){
		StringBuffer sql = new StringBuffer();
		List<String> pks = Arrays.asList(primaryKeys);
		if((table == null)||table.equals("")){
			table = getSqlName(cls.getSimpleName());
		}
		sql.append("CREATE TABLE '"+table+"'(\n");
		String[] fields = getAllFields(cls);
		for(String field:fields){
			sql.append("\t'"+getSqlName(field)+"' ");
			String type = getFieldType(field,cls);
			if(pks.contains(field)){
				sql.append(type+" NOT NULL,\n");
			}else{
				sql.append(type+",\n");
			}
		}
		sql.append("\t PRIMARY KEY (");
		for(String pk:pks){
			sql.append("'"+pk+"',");
		}
		sql.deleteCharAt(sql.length()-1);
		sql.append(")\n) ENGINE=InnoDB DEFAULT CHARSET=utf8 COLLATE=utf8_bin;");
		return sql.toString();
	}

	/**
	 * 根据set方法和Class.getFields(),获取该类的所有域,包括公共域和私有域
	 * @param cls 类型
	 * @return
	 */
	public static String[] getAllFields(Class cls) {
		Field[] pubFields = cls.getFields();
		List<String> setFields = new ArrayList<String>();
		Method[] methods = cls.getMethods();
		for(Method method : methods){
			if(method.getName().startsWith("set")){
				String field = method.getName().substring(3).toLowerCase();
				setFields.add(field);
			}
		}
		for(Field temp:pubFields){
			if(!setFields.contains(temp.getName())){
				setFields.add(temp.getName());
			}
		}
		return setFields.toArray(new String[setFields.size()]);
	}
	/**
	 * 根据域名获取该对象域的值
	 * @param field 域名
	 * @param obj 对象
	 * @param cls 类型
	 * @return
	 */
	public Object getValues(String field, Object obj, Class cls) {
		Method[] methods = cls.getMethods();
		String firstChar = field.charAt(0)+"";
		firstChar = firstChar.toUpperCase();
		String methodName = "get"+firstChar+field.substring(1);
		for(Method method : methods){
			if(method.getName().equals(methodName)){
				try {
					return method.invoke(obj, null);
				} catch (IllegalAccessException e) {
					e.printStackTrace();
				} catch (IllegalArgumentException e) {
					e.printStackTrace();
				} catch (InvocationTargetException e) {
					e.printStackTrace();
				}
			}
		}
		return null;
	}
	/**
	 * 获取域的类型和大小
	 * @param field
	 * @param cls
	 * @return
	 */
	private static String getFieldType(String field,Class cls) {
		Method[] methods = cls.getMethods();
		String methodName = "get"+getFirstUpperString(field);
		for(Method method:methods){
			if(method.getName().equals(methodName)){
				Class clss = method.getReturnType();
				String typeName = clss.getSimpleName();
				if(typeName.equals("String")){
					typeName = "varchar(255)";
				}else if(typeName.equals("Date")){
					typeName = "datatime()";
				}else if(typeName.equals("int")){
					typeName = "int(32)";
				}
				return typeName;
			}
		}
		return null;
	}
	/**
	 * 获取数据库用的名称(TestBean==>test_bean)
	 * @param name
	 * @return
	 */
	public static String getSqlName(String name){
		if((name == null)||(name.equals("")))return null;
		name = (name.charAt(0)+"").toLowerCase()+name.substring(1);
		int index = 0;
		while((index = contains(name))!=-1){
			char c = name.charAt(index);
			name = name.replaceFirst(c+"", ("_"+c).toLowerCase());
		}
		return name;
	}
	/**
	 * 返回字符串中第一个大写字母的索引
	 * @param str
	 * @return 没有大写字母返回-1
	 */
	public static int contains(String str){
		for(int i=0;i<str.length();i++){
			char c = str.charAt(i);
			if(('A'<=c)&&(c<='Z')){
				return i;
			}
		}return -1;
	}
	/**
	 * 将字符串的第一个字符大写
	 * @param str
	 * @return
	 */
	public static String getFirstUpperString(String str){
		String firstChar = str.charAt(0)+"";
		firstChar = firstChar.toUpperCase();
		return firstChar+str.substring(1);
	}
	/**
	 * 将字符串的第一个字符小写
	 * @param str
	 * @return
	 */
	public static String getFirstLowString(String str){
		String firstChar = str.charAt(0)+"";
		firstChar = firstChar.toLowerCase();
		return firstChar+str.substring(1);
	}
}




0
0

查看评论
* 以上用户言论只代表其个人观点,不代表CSDN网站的观点或立场
    个人资料
    • 访问:50986次
    • 积分:818
    • 等级:
    • 排名:千里之外
    • 原创:29篇
    • 转载:5篇
    • 译文:0篇
    • 评论:25条
    访问统计
    为什么不能显示统计:https://www.revolvermaps.com
    最新评论