事务管理的实现
TransactionManager 类
/**
 * Static helpers that demarcate a JDBC transaction on the connection returned by
 * {@code ConnUtil.getConn()}.
 */
public class TransactionManager {

    /**
     * Begins a transaction by switching the connection's auto-commit mode off.
     *
     * @throws SQLException if the auto-commit mode cannot be changed
     */
    public static void beginTrans() throws SQLException {
        ConnUtil.getConn().setAutoCommit(false);
    }

    /**
     * Commits the current transaction and then closes the connection.
     *
     * <p>The connection is deliberately left open when the commit itself fails, so that a
     * subsequent {@link #rollback()} operates on the same connection.
     *
     * @throws SQLException if the commit or the close fails
     */
    public static void commit() throws SQLException {
        Connection conn = ConnUtil.getConn();
        conn.commit();
        ConnUtil.closeConn();
    }

    /**
     * Rolls the current transaction back and closes the connection.
     *
     * <p>The connection is closed even when the rollback fails, so it is never left behind.
     *
     * @throws SQLException if the rollback or the close fails
     */
    public static void rollback() throws SQLException {
        Connection conn = ConnUtil.getConn();
        try {
            // Undo the pending work; the previous implementation committed it by mistake.
            conn.rollback();
        } finally {
            ConnUtil.closeConn();
        }
    }
}
OpenSessionInViewFilter 类
/**
 * Servlet filter that wraps every {@code *.do} request in one transaction: the transaction is
 * begun before the rest of the chain runs, committed when the chain returns normally and rolled
 * back when it fails.
 */
@WebFilter("*.do")
public class OpenSessionInViewFilter implements Filter {

    /** No initialisation is required. */
    @Override
    public void init(FilterConfig filterConfig) throws ServletException {
    }

    /**
     * Runs the remainder of the filter chain inside a transaction.
     *
     * <p>A {@link SQLException} raised while beginning or committing is printed and answered with
     * a rollback, as before. Any {@link IOException}, {@link ServletException} or
     * {@link RuntimeException} coming out of the chain now also triggers a rollback before it is
     * rethrown unchanged; previously such failures escaped without any rollback.
     *
     * @param servletRequest  the incoming request
     * @param servletResponse the outgoing response
     * @param filterChain     the remaining filters and the target servlet
     * @throws IOException      rethrown from the chain after rolling back
     * @throws ServletException rethrown from the chain after rolling back
     */
    @Override
    public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain) throws IOException, ServletException {
        try {
            TransactionManager.beginTrans();
            filterChain.doFilter(servletRequest, servletResponse);
            TransactionManager.commit();
        } catch (SQLException e) {
            e.printStackTrace();
            rollbackQuietly();
        } catch (IOException | ServletException | RuntimeException e) {
            // Business failures usually arrive as unchecked exceptions; undo the work first.
            rollbackQuietly();
            throw e;
        }
    }

    /** No cleanup is required. */
    @Override
    public void destroy() {
    }

    /** Rolls the transaction back, printing (not propagating) a failure of the rollback itself. */
    private static void rollbackQuietly() {
        try {
            TransactionManager.rollback();
        } catch (SQLException ex) {
            ex.printStackTrace();
        }
    }
}
ConnUtil 工具类
/**
 * Hands out one JDBC connection per thread, so that all data access performed while handling a
 * request shares the same connection (and therefore the same transaction).
 */
public class ConnUtil {
    /** Holds the connection that belongs to the current thread, if one has been opened. */
    private static final ThreadLocal<Connection> threadLocal = new ThreadLocal<>();

    public static final String DRIVER = "com.mysql.jdbc.Driver" ;
    public static final String URL = "jdbc:mysql://localhost:3306/fruitdb?useUnicode=true&characterEncoding=utf-8&useSSL=false";
    // NOTE(review): credentials are hard-coded in source; consider externalising them.
    public static final String USER = "root";
    public static final String PWD = "123456";

    /**
     * Loads the driver and opens a new physical connection.
     *
     * @return a freshly opened connection, never {@code null}
     * @throws IllegalStateException if the driver class is missing or the connection attempt
     *         fails; the underlying exception is attached as the cause
     */
    private static Connection createConn() {
        try {
            // 1. Load the driver class.
            Class.forName(DRIVER);
            // 2. Ask the driver manager for a connection.
            return DriverManager.getConnection(URL, USER, PWD);
        } catch (ClassNotFoundException | SQLException e) {
            // Fail here with the real cause instead of returning null and letting the caller
            // hit a NullPointerException later.
            throw new IllegalStateException("Unable to open a database connection", e);
        }
    }

    /**
     * Returns the current thread's connection, opening and remembering one on first use.
     *
     * @return the connection bound to the current thread
     * @throws IllegalStateException if a new connection is needed but cannot be opened
     */
    public static Connection getConn() {
        Connection conn = threadLocal.get();
        if (conn == null) {
            conn = createConn();
            threadLocal.set(conn);
        }
        return conn;
    }

    /**
     * Closes the current thread's connection, if any, and forgets it.
     *
     * <p>The thread-local slot is always cleared, even when the connection was already closed
     * or closing fails, so a reused thread never receives a stale connection from
     * {@link #getConn()}.
     *
     * @throws SQLException if checking or closing the connection fails
     */
    public static void closeConn() throws SQLException {
        Connection conn = threadLocal.get();
        if (conn == null) {
            return;
        }
        try {
            if (!conn.isClosed()) {
                conn.close();
            }
        } finally {
            threadLocal.remove();
        }
    }
}
DAOException 自定义异常
/** Unchecked exception signalling a failure in the data-access layer. */
public class DAOException extends RuntimeException {

    /**
     * Creates an exception with a description only.
     *
     * @param msg description of the failure
     */
    public DAOException(String msg) {
        super(msg);
    }

    /**
     * Creates an exception that keeps the underlying failure as its cause.
     *
     * @param msg   description of the failure
     * @param cause the exception that triggered this one; may be {@code null}
     */
    public DAOException(String msg, Throwable cause) {
        super(msg, cause);
    }
}
BaseDAO<T> 类
/**
 * Generic JDBC helper for DAO implementations.
 *
 * <p>A concrete subclass binds {@code T} to an entity class; query results are mapped onto that
 * class reflectively by writing each column value into the field whose name equals the column
 * name. Statements and result sets are opened and closed inside each call. The connection comes
 * from {@code ConnUtil} and is not closed here.
 *
 * @param <T> entity type; it is instantiated through its no-argument constructor
 */
public abstract class BaseDAO<T> {
    // These three fields are kept only so that subclasses referring to them still compile.
    // The methods below use local variables instead, so per-call JDBC state is no longer
    // stored on the DAO instance; the fields are not populated.
    protected Connection conn ;
    protected PreparedStatement psmt ;
    protected ResultSet rs ;

    /** Class object of the entity type {@code T}. */
    private final Class<T> entityClass ;

    /**
     * Resolves the entity class from the type argument the concrete subclass supplies to
     * {@code BaseDAO<...>}.
     *
     * <p>{@code getClass()} yields the concrete subclass, so {@code getGenericSuperclass()}
     * describes {@code BaseDAO<Entity>} including its actual type argument.
     *
     * @throws DAOException if the subclass does not supply a type argument or the argument
     *         cannot be loaded as a class
     */
    @SuppressWarnings("unchecked") // the loaded class is, by construction, the T of the subclass
    public BaseDAO(){
        Type genericType = getClass().getGenericSuperclass();
        if (!(genericType instanceof ParameterizedType)) {
            // The subclass extends the raw type; report it with the same message as below.
            throw new DAOException("BaseDAO 构造方法出错了,可能的原因是没有指定<>中的类型");
        }
        // The first (and only) type argument is the actual type of T.
        Type actualType = ((ParameterizedType) genericType).getActualTypeArguments()[0];
        Class<?> resolved;
        try {
            resolved = Class.forName(actualType.getTypeName());
        } catch (ClassNotFoundException e) {
            throw failure("BaseDAO 构造方法出错了,可能的原因是没有指定<>中的类型", e);
        }
        entityClass = (Class<T>) resolved;
    }

    /** Builds a {@link DAOException} that carries {@code cause} as its cause. */
    private static DAOException failure(String message, Throwable cause) {
        DAOException wrapped = new DAOException(message);
        wrapped.initCause(cause);
        return wrapped;
    }

    /**
     * Returns the connection to run statements on.
     *
     * @return the connection supplied by {@code ConnUtil.getConn()}
     */
    protected Connection getConn(){
        return ConnUtil.getConn();
    }

    /**
     * Does nothing; the body is intentionally empty.
     * NOTE(review): presumably kept for subclasses that still call it while the connection
     * lifetime is managed elsewhere - confirm against callers.
     *
     * @param rs   ignored
     * @param psmt ignored
     * @param conn ignored
     */
    protected void close(ResultSet rs , PreparedStatement psmt , Connection conn){
    }

    /**
     * Binds {@code params} to the statement's placeholders in order, starting at index 1.
     *
     * @param psmt   statement to bind to
     * @param params values to bind; may be {@code null} or empty
     * @throws SQLException if a value cannot be bound
     */
    private void setParams(PreparedStatement psmt , Object... params) throws SQLException {
        if (params == null) {
            return;
        }
        for (int i = 0; i < params.length; i++) {
            psmt.setObject(i + 1, params[i]);
        }
    }

    /**
     * Executes an INSERT, UPDATE or DELETE statement.
     *
     * @param sql    statement text with {@code ?} placeholders
     * @param params values for the placeholders
     * @return for an INSERT that produced a generated key, that key narrowed to {@code int};
     *         otherwise the number of affected rows
     * @throws DAOException if the statement fails; the {@link SQLException} is the cause
     */
    protected int executeUpdate(String sql , Object... params){
        // Case-insensitive and locale-independent check for a leading INSERT keyword.
        boolean insertFlag = sql.trim().regionMatches(true, 0, "INSERT", 0, "INSERT".length());
        Connection connection = getConn();
        try (PreparedStatement statement = insertFlag
                ? connection.prepareStatement(sql, Statement.RETURN_GENERATED_KEYS)
                : connection.prepareStatement(sql)) {
            setParams(statement, params);
            int count = statement.executeUpdate();
            if (insertFlag) {
                try (ResultSet keys = statement.getGeneratedKeys()) {
                    if (keys.next()) {
                        return (int) keys.getLong(1);
                    }
                }
            }
            // Previously a constant 0 was returned here, hiding the real row count.
            return count;
        } catch (SQLException e) {
            throw failure("BaseDAO executeUpdate出错了", e);
        }
    }

    /**
     * Reflectively assigns {@code propertyValue} to the field named {@code property} declared
     * on the class of {@code obj}.
     *
     * @param obj           target object
     * @param property      name of the declared field, e.g. {@code "fid"}
     * @param propertyValue value to store
     * @throws NoSuchFieldException   if the class declares no such field
     * @throws IllegalAccessException if the field cannot be written
     */
    private void setValue(Object obj , String property , Object propertyValue) throws NoSuchFieldException, IllegalAccessException {
        // getDeclaredField never returns null; it throws when the field is missing.
        Field field = obj.getClass().getDeclaredField(property);
        field.setAccessible(true);
        field.set(obj, propertyValue);
    }

    /**
     * Creates one entity and fills it from the current row of {@code resultSet}.
     *
     * @param resultSet   result set positioned on the row to read
     * @param metaData    metadata of {@code resultSet}
     * @param columnCount number of columns in the result set
     * @return the populated entity
     * @throws SQLException                 if a column cannot be read
     * @throws ReflectiveOperationException if the entity cannot be created or a field cannot
     *                                      be found or written
     */
    private T mapRow(ResultSet resultSet , ResultSetMetaData metaData , int columnCount)
            throws SQLException, ReflectiveOperationException {
        T entity = entityClass.getDeclaredConstructor().newInstance();
        for (int column = 1; column <= columnCount; column++) {
            String columnName = metaData.getColumnName(column);
            Object columnValue = resultSet.getObject(column);
            setValue(entity, columnName, columnValue);
        }
        return entity;
    }

    /**
     * Executes a query whose result is not an entity (for example an aggregate) and returns
     * the first row as raw column values.
     *
     * @param sql    query text with {@code ?} placeholders
     * @param params values for the placeholders
     * @return the column values of the first row, or {@code null} if the query returned no rows
     * @throws DAOException if the query fails; the {@link SQLException} is the cause
     */
    protected Object[] executeComplexQuery(String sql , Object... params){
        Connection connection = getConn();
        try (PreparedStatement statement = connection.prepareStatement(sql)) {
            setParams(statement, params);
            try (ResultSet resultSet = statement.executeQuery()) {
                // The metadata tells how many columns each row has.
                int columnCount = resultSet.getMetaData().getColumnCount();
                if (!resultSet.next()) {
                    return null;
                }
                Object[] columnValueArr = new Object[columnCount];
                for (int i = 0; i < columnCount; i++) {
                    columnValueArr[i] = resultSet.getObject(i + 1);
                }
                return columnValueArr;
            }
        } catch (SQLException e) {
            throw failure("BaseDAO executeComplexQuery出错了", e);
        }
    }

    /**
     * Executes a query and maps its first row to an entity.
     *
     * @param sql    query text with {@code ?} placeholders
     * @param params values for the placeholders
     * @return the entity built from the first row, or {@code null} if the query returned no rows
     * @throws DAOException if the query or the mapping fails; the original exception is the cause
     */
    protected T load(String sql , Object... params){
        Connection connection = getConn();
        try (PreparedStatement statement = connection.prepareStatement(sql)) {
            setParams(statement, params);
            try (ResultSet resultSet = statement.executeQuery()) {
                ResultSetMetaData metaData = resultSet.getMetaData();
                int columnCount = metaData.getColumnCount();
                return resultSet.next() ? mapRow(resultSet, metaData, columnCount) : null;
            }
        } catch (Exception e) {
            throw failure("BaseDAO load出错了", e);
        }
    }

    /**
     * Executes a query and maps every row to an entity.
     *
     * @param sql    query text with {@code ?} placeholders
     * @param params values for the placeholders
     * @return the mapped entities in result order; empty if the query returned no rows
     * @throws DAOException if the query or the mapping fails; the original exception is the cause
     */
    protected List<T> executeQuery(String sql , Object... params){
        List<T> list = new ArrayList<>();
        Connection connection = getConn();
        try (PreparedStatement statement = connection.prepareStatement(sql)) {
            setParams(statement, params);
            try (ResultSet resultSet = statement.executeQuery()) {
                ResultSetMetaData metaData = resultSet.getMetaData();
                int columnCount = metaData.getColumnCount();
                while (resultSet.next()) {
                    list.add(mapRow(resultSet, metaData, columnCount));
                }
            }
        } catch (Exception e) {
            throw failure("BaseDAO executeQuery出错了", e);
        }
        return list;
    }
}
DispatcherServlet 类
/**
 * Front controller for every {@code *.do} request.
 *
 * <p>The servlet path selects a controller bean ({@code /fruit.do} looks up {@code "fruit"}),
 * the {@code operate} request parameter selects the method to call on it (default
 * {@code "index"}), and the method's String return value selects either a redirect or a
 * template to render.
 */
@WebServlet("*.do")
public class DispatcherServlet extends ViewBaseServlet {
    private BeanFactory beanFactory;

    public DispatcherServlet() {
    }

    /**
     * Initialises the superclass and creates the bean factory used to look up controllers.
     *
     * @throws ServletException if the superclass initialisation fails
     */
    @Override
    public void init() throws ServletException {
        super.init();
        beanFactory = new ClassPathXmlApplicationContext();
    }

    /**
     * Dispatches the request to the matching controller method and processes the returned view.
     *
     * @param request  current request
     * @param response current response
     * @throws DispatcherServletException if no controller bean or handler method matches, or
     *         if argument binding, invocation or view processing fails (the original
     *         exception is attached as the cause)
     */
    @Override
    protected void service(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException {
        // Example: for http://localhost:8080/pro15/hello.do the servlet path is "/hello.do".
        // Strip the leading '/' and the ".do" suffix to obtain the bean name "hello".
        String servletPath = request.getServletPath();
        servletPath = servletPath.substring(1);
        int lastDotIndex = servletPath.lastIndexOf(".do");
        servletPath = servletPath.substring(0, lastDotIndex);

        Object controllerBeanObj = beanFactory.getBean(servletPath);
        if (controllerBeanObj == null) {
            throw new DispatcherServletException("No controller bean is registered for '" + servletPath + "'");
        }

        String operate = request.getParameter("operate");
        if (StringUtil.isEmpty(operate)) {
            operate = "index";
        }

        Method handler = findHandler(controllerBeanObj, operate);
        if (handler == null) {
            // Previously an unknown operate value silently produced an empty response.
            throw new DispatcherServletException("operate的值非法!");
        }

        try {
            // 1. Collect the argument values for the handler.
            Object[] parameterValues = resolveArguments(handler, request, response);
            handler.setAccessible(true);
            // 2. Invoke the controller method reflectively.
            Object returnObj = handler.invoke(controllerBeanObj, parameterValues);
            // 3. View processing.
            String methodReturnStr = (String) returnObj;
            if (methodReturnStr.startsWith("redirect:")) {
                // e.g. "redirect:fruit.do"
                String redirectStr = methodReturnStr.substring("redirect:".length());
                response.sendRedirect(redirectStr);
            } else {
                // e.g. "edit"
                super.processTemplate(methodReturnStr, request, response);
            }
        } catch (Exception e) {
            DispatcherServletException wrapped = new DispatcherServletException("DispatcherServlet出错了...");
            // Keep the root cause so the real failure stays visible in the stack trace.
            wrapped.initCause(e);
            throw wrapped;
        }
    }

    /**
     * Finds the first method declared on the controller's class whose name equals
     * {@code operate}. Only that one method is invoked per request.
     *
     * @param controller controller bean to inspect
     * @param operate    requested method name
     * @return the matching method, or {@code null} if none is declared
     */
    private Method findHandler(Object controller, String operate) {
        for (Method candidate : controller.getClass().getDeclaredMethods()) {
            if (operate.equals(candidate.getName())) {
                return candidate;
            }
        }
        return null;
    }

    /**
     * Builds the argument array for {@code handler}.
     *
     * <p>Parameters named {@code request}, {@code response} and {@code session} receive the
     * corresponding servlet objects. Every other parameter receives the request parameter of
     * the same name, converted with {@link Integer#parseInt(String)} when the declared type is
     * {@code java.lang.Integer} and passed as a String (or {@code null} when absent) otherwise.
     * Matching by name only works when the controller classes are compiled with parameter
     * names retained (javac {@code -parameters}); otherwise the names are {@code arg0},
     * {@code arg1}, and so on.
     *
     * @param handler  method about to be invoked
     * @param request  current request
     * @param response current response
     * @return argument values in declaration order
     * @throws NumberFormatException if an Integer parameter's value is not a valid integer
     */
    private Object[] resolveArguments(Method handler, HttpServletRequest request, HttpServletResponse response) {
        Parameter[] parameters = handler.getParameters();
        Object[] parameterValues = new Object[parameters.length];
        for (int i = 0; i < parameters.length; i++) {
            Parameter parameter = parameters[i];
            String parameterName = parameter.getName();
            if ("request".equals(parameterName)) {
                parameterValues[i] = request;
            } else if ("response".equals(parameterName)) {
                parameterValues[i] = response;
            } else if ("session".equals(parameterName)) {
                parameterValues[i] = request.getSession();
            } else {
                String parameterValue = request.getParameter(parameterName);
                Object parameterObj = parameterValue;
                if (parameterValue != null && "java.lang.Integer".equals(parameter.getType().getName())) {
                    // Convert here; passing the raw String to an Integer parameter would fail
                    // with "argument type mismatch" on invoke.
                    parameterObj = Integer.parseInt(parameterValue);
                }
                parameterValues[i] = parameterObj;
            }
        }
        return parameterValues;
    }
}
//常见错误:IllegalArgumentException: argument type mismatch