对CodeSmith + netTiers 生成DAL的一点补充


 


版本
CodeSmith 4.0
netTiers 2.0.1

背景
        最近在项目中使用CodeSmith + netTiers 生成数据访问层DAL,感觉效果很好,减少了大量的简单重复劳动。
不过在使用过程中发现CodeSmith提供的方法不能完全满足项目需要,主要体现在两个方面:
1、 Data.DataRepository.TableProvider.GetPaged方法无法输入带参数的条件,调用前必须进行SQL 拼接,
这样可能导致SQL 注入攻击。
2、 DataRepository.Provider.ExecuteDataSet 无法分页查询

为解决以上问题,我做了如下代码对生成的DAL进行了补充。这些代码可以在DAL外部使用,也可以修改netTiers
模板,内置到DAL中。


    /** <summary>
    /// 带参数的条件查询子句异常
    /// </summary>
    public class ParaWhereStringException : Exception
    {
        public ParaWhereStringException(String message)
            : base(message)
        {

        }
    }

    /** <summary>
    /// 带参数的条件查询子句
    /// </summary>
    public class ParaWhereString
    {
        enum T_STATE
        {
            Idle = 0,
            At   = 1,
            Str  = 2,
        }

        T_STATE m_State;
        int m_LastPos;
        int m_CurPos;
        String m_WhereString;
        List<String> m_Words = new List<string>();

        private void Clear()
        {
            m_State = T_STATE.Idle;
            m_LastPos = 0;
            m_CurPos = 0;
            m_WhereString = "";
            m_Words = new List<string>();
        }

        private void ChangeState(T_STATE curState)
        {
            m_State = curState;
            NewWord();
        }

        private void EndWord()
        {
            m_Words.Add(m_WhereString.Substring(m_LastPos, m_WhereString.Length - m_LastPos));
        }

        private void NewWord()
        {
            m_Words.Add(m_WhereString.Substring(m_LastPos, m_CurPos - m_LastPos));
            m_LastPos = m_CurPos;
        }

        private void StateMachine(char ch)
        {
            switch (m_State)
            {
                case T_STATE.Idle:
                    if (ch == '@')
                    {
                        ChangeState(T_STATE.At);
                    }
                    else if (ch == '/'')
                    {
                        ChangeState(T_STATE.Str);
                    }

                    break;
                case T_STATE.At:
                    if ((ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z') || ch == '_')
                    {
                        break;
                    }

                    if (ch >= '0' && ch <= '9' && m_CurPos - m_LastPos > 1)
                    {
                        break;
                    }

                    if (ch == '/'')
                    {
                        ChangeState(T_STATE.Str);
                    }
                    else
                    {
                        ChangeState(T_STATE.Idle);
                    }

                    break;

                case T_STATE.Str:
                    if (ch == '/'')
                    {
                        m_CurPos++;

                        if (m_WhereString[m_CurPos] == '/'')
                        {
                            break;
                        }
                        else
                        {
                            ChangeState(T_STATE.Idle);
                        }
                    }
                    break;
            }

            if (m_CurPos == m_WhereString.Length - 1)
            {
                //无论任何状态,只要到了最后一个字符,结束状态机
                EndWord();
                return;
            }
        }

        private void SplitWhereString(String whereString)
        {
            System.Diagnostics.Debug.Assert(whereString != null);

            m_State = T_STATE.Idle;

            m_LastPos = 0;
            m_CurPos = 0;

            while (m_CurPos < whereString.Length)
            {
                StateMachine(whereString[m_CurPos]);
                m_CurPos++;
            }
        }

        private String GetParaValue(String paraName, object value)
        {
            if ((value is int) || (value is uint) ||
                (value is short) || (value is ushort) ||
                (value is sbyte) || (value is byte) ||
                (value is long) || (value is ulong) ||
                (value is float) || (value is double)
                )
            {
                return value.ToString();
            }

            if ((value is string) || (value is char))
            {
                return "'" + value.ToString().Replace("'", "''") + "'";
            }

            if (value is DateTime)
            {
                DateTime d = (DateTime)value;

                return "'" + d.ToString("yyyy-MM-dd HH:mm:ss") + "'";
            }

            if (value == DBNull.Value)
            {
                return "NULL";
            }

            throw new ParaWhereStringException(String.Format("invalid type of para={0}!",
                paraName));
        }

        /** <summary>
        /// 根据参数获取条件子句
        /// </summary>
        /// <param name="whereString">
        /// 带参数的条件子句,如
        /// "Price>@MinPrice and Price < @MaxPrice"
        /// </param>
        /// <param name="parameters">参数列表</param>
        /// <returns>获取实际的条件子句,如 "Price > 10 and Price < 100"</returns>
        public String GetWhereString(String whereString, List<SqlParameter> parameters)
        {
            if (parameters == null)
            {
                return whereString;
            }

            Clear();

            m_WhereString = whereString;
            SplitWhereString(whereString);

            Hashtable table = new Hashtable();

            foreach (SqlParameter para in parameters)
            {
                if (para.Value == null)
                {
                    table['@' + para.ParameterName.ToLower()] = DBNull.Value;
                }
                else
                {
                    table['@' + para.ParameterName.ToLower()] = para.Value;
                }
            }

            StringBuilder whereStr = new StringBuilder();

            foreach (String str in m_Words)
            {
                if (str.Length > 0)
                {
                    if (str[0] == '@')
                    {
                        object value = table[str.ToLower().Trim()];
                        if (value == null)
                        {
                            throw new ParaWhereStringException(String.Format("para={0} does not in parameters!",
                                str));
                        }

                        whereStr.Append(GetParaValue(str, value));
                        continue;
                    }
                }

                whereStr.Append(str);
            }

            return whereStr.ToString();
        }

    }

    /** <summary>
    /// 数据存储扩展
    /// </summary>
    public class DataRepositoryEx
    {
        /** <summary>
        /// 获取分页的查询结果,查询语句必须是
        /// Select 形式的,不能处理存储过程
        /// </summary>
        /// <param name="fields">where 子句前面的部分,不能有top关键字 如 “Price,ReleaseTime, RecName as Address”</param>
        /// <param name="tableName">要查询的表名</param>
        /// <param name="condition">带参数的 where子句,不包括where关键字 如 “Price > @MinPrice and Price < @MaxPrice”</param>
        /// <param name="parameters">where子句的参数</param>
        /// <param name="orderBy">order by 子句部分, 如果有Group by 也可以写在这里 如“order by ReleaseTime ASC”</param>
        /// <param name="pageNo">页面号,从0开始编号</param>
        /// <param name="pageLength">页面长度,即每页面记录数</param>
        /// <param name="count">输出查询结果的总数</param>
        /// <returns>以数据表形式返回查询结果集</returns>
        static public DataTable SelectPaged(String fields, String tableName,
            String condition, List<SqlParameter> parameters, String orderBy, int pageNo, int pageLength, out int count)
        {
            System.Diagnostics.Debug.Assert(pageNo >= 0);
            System.Diagnostics.Debug.Assert(pageLength > 0);

            ParaWhereString paraWhereStr = new ParaWhereString();
            String sqlCond = paraWhereStr.GetWhereString(condition, parameters);

            String sql;

            if (condition == null)
            {
                condition = "";
            }

            if (condition == "")
            {
                sql = String.Format("select count(*) cnt from {0}", tableName);
            }
            else
            {
                sql = String.Format("select count(*) cnt from {0} where {1}", tableName, sqlCond);
            }

            DataSet ds = DataRepository.Provider.ExecuteDataSet(CommandType.Text, sql);

            count = (int)ds.Tables[0].Rows[0]["cnt"];

            int upperBound = (pageNo + 1) * pageLength;

            int lowerBound = pageNo * pageLength;

            if (condition == "")
            {
                sql = String.Format("select top {0} {1} from {2} ", upperBound, fields, tableName);
            }
            else
            {
                sql = String.Format("select top {0} {1} from {2} where {3} ", upperBound, fields, tableName, sqlCond);
            }

            if (orderBy != "" && orderBy != null)
            {
                sql += orderBy;
            }

            ds = DataRepository.Provider.ExecuteDataSet(CommandType.Text, sql);

            if (ds.Tables[0].Rows.Count <= lowerBound)
            {
                ds.Tables[0].Clear();
            }
            else
            {
                for (int i = 0; i < lowerBound; i++)
                {
                    ds.Tables[0].Rows.RemoveAt(0);
                }
            }

            return ds.Tables[0];
        }
    }

ParaWhereString  类用于将带参数的条件子句转换为不带参数的条件子句,供GetPaged,GetAll两个方法使用。这个类是一个通用的类,也可以用于
其他应用中获取带参数的条件子句的最终转换后的条件子句。

DataRepositoryEx 类提供分查询的方法。

ParaWhereString 调用示例
            ParaWhereString paraWhereString = new ParaWhereString();

            string whereString = "price>@minPrice and price <= @maxPrice and str like '%adb''@aaa dsafj'";

            List<SqlParameter> paras = new List<SqlParameter>();
               paras.Add(new SqlParameter("minPrice", 100));
               paras.Add(new SqlParameter("MaxPrice", 1000));


            String sql = paraWhereString.GetWhereString(whereString, paras);

            Console.WriteLine(sql);输出结果:
price>100 and price <= 1000 and str like '%adb''@aaa dsafj'

DataRepositoryEx 调用示例

用于测试的表结构

use Test

   GO

Create Table Test
(
id int identity (1,1) not null,
a int

)


向表Test中插入若干条连续的记录

查询分页数据示例

 


        int count;

        List<SqlParameter> paras = new List<SqlParameter>();
        paras.Add(new SqlParameter("min", 3));
        paras.Add(new SqlParameter("max", 30));

        DataTable table = SecUser.Cert.BLL.DataRepositoryEx.SelectPaged("id, a", "test..test", "id >= @min and id < @max",
            paras, "order by id DESC", 0, 10, out count);

        Response.Write(String.Format("Count={0}", count));

        foreach(DataRow row in table.Rows)
        {
            Response.Write(String.Format("</p>{0}", row["id"]));
        }

查询结果:

Count=27

29

28

27

26

25

24

23

22

21

20


本文来自CSDN博客,转载请标明出处:http://blog.csdn.net/eaglet/archive/2007/07/26/1709671.aspx

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值