// Package dao is a thin Go wrapper around database operations (Go封装操作数据库).
package dao
import (
"bytes"
"database/sql"
"fmt"
"lzy/framework/util"
_ "mysql"
"reflect"
"strconv"
"strings"
)
// MultiFunction describes one SQL operation over a model struct: the model
// itself, the WHERE-clause inputs, paging/sorting options, the generated
// statement with its bind values, and the connection settings used by the
// Execute methods.
type MultiFunction struct {
	// Model is the struct pointer whose fields are read via reflection to
	// build column lists and bind values.
	Model interface{}
	// ResultSet holds the rows of the last query, keyed by row index and then
	// by column name; every value is the column's raw bytes as a string.
	ResultSet map[int]map[string]string
	// Rule maps a model field name to an operator fragment (EQ, LIKE, ...).
	// Two fragments joined by "@_data_@" form a range condition whose two
	// bounds come from the field's string value split on the same marker.
	Rule map[string]string
	// Field, when non-empty, lists (as map values) the columns to select.
	Field map[string]string
	// KV lists, as map values, the model field names assigned in an UPDATE.
	KV map[string]string
	// cv holds the positional bind values for sql; tag counts how many are in use.
	cv []interface{}
	// Ids are matched with "<Id> in (...)".
	Ids []string
	// IsQueryAll skips the WHERE clause in generated SELECTs.
	IsQueryAll bool
	pageSize   int
	// nowPage is 1-based: GetBeginRow computes pageSize*(nowPage-1).
	nowPage int
	// total is the cached row count; -1 means it has not been counted yet.
	total     int
	totalPage int
	// IsPagin appends a "limit ?, ?" clause to generated SELECTs.
	IsPagin bool
	// OrderBy and OrderByType are written verbatim after "order by" when IsSort is set.
	OrderBy     string
	OrderByType string
	// tableN is the table name, copied from the model's "tableN" field.
	tableN string
	// sql is the statement the Execute methods run.
	sql    string
	IsSort bool
	tag    int64
	// IsDistinct adds " distinct " to generated SELECTs.
	IsDistinct bool
	// Id is the column name matched against Ids (NewMultiFunction uses "Id").
	Id string
	// mysql is set by SetSql: the Execute methods then run sql as given
	// instead of generating it from Model.
	mysql bool
	// DbAddr and DbType are the data source name and driver name passed to sql.Open.
	DbAddr string
	DbType string
	// beginRow is an explicit limit offset; a negative value means "derive it
	// from nowPage and pageSize".
	beginRow int
	// Num is the rows-affected count of the last Insert, Delete or Update.
	Num int64
}
// NewMultiFunction wraps o in a MultiFunction with default settings.
//
// The defaults are:
//   - 10 rows per page, starting at page 1;
//   - pagination and sorting enabled, ordering by "Id" descending;
//   - total unknown (-1);
//   - the page offset derived from the page number.
//
// o must be a pointer to a struct, because its Elem is taken via reflection.
//
// The table name is read from o's string field named "tableN". When o has
// no such field, or the field is not a string, the table name is left empty.
// Otherwise reflect's placeholder text (e.g. "<invalid Value>") would be
// used as a table name in generated SQL.
func NewMultiFunction(o interface{}) *MultiFunction {
	tableN := ""
	if f := reflect.ValueOf(o).Elem().FieldByName("tableN"); f.IsValid() && f.Kind() == reflect.String {
		tableN = f.String()
	}
	return &MultiFunction{
		Model:       o,
		Rule:        make(map[string]string),
		Field:       make(map[string]string),
		KV:          make(map[string]string),
		pageSize:    10,
		nowPage:     1,
		IsPagin:     true,
		IsDistinct:  false,
		OrderBy:     "Id",
		OrderByType: "desc",
		IsQueryAll:  false,
		IsSort:      true,
		total:       -1,
		Id:          "Id",
		mysql:       false,
		tag:         0,
		beginRow:    -1,
		tableN:      tableN,
		Num:         0,
	}
}
// SetSql installs a hand-written statement. From then on the Execute methods
// run it, with the current bind values in cv, instead of generating SQL
// from Model.
func (m *MultiFunction) SetSql(sql string) {
	m.sql = sql
	m.mysql = true
}
// GetPageSize reports how many rows one page holds (NewMultiFunction
// starts it at 10).
func (m *MultiFunction) GetPageSize() int {
	size := m.pageSize
	return size
}
// SetPageSize sets the number of rows per page. The value is used for the
// limit clause and the page count. It is stored without validation.
func (m *MultiFunction) SetPageSize(pageSize int) {
	size := pageSize
	m.pageSize = size
}
// SetNowPage selects the 1-based page used to compute the limit offset.
//
// n is clamped to [1, total pages]:
//   - Values below 1 become 1 without touching the database.
//   - Otherwise GetTotalPage is consulted (it may run a count query) and n
//     is capped at the last page.
//   - n never goes below 1, even for an empty table, so GetBeginRow cannot
//     produce a negative offset.
//
// The previous version computed both clamps and then overwrote them with n
// unconditionally.
func (m *MultiFunction) SetNowPage(n int) {
	if n < 1 {
		n = 1
	} else if last := m.GetTotalPage(); n > last {
		n = last
		if n < 1 {
			n = 1
		}
	}
	m.nowPage = n
}
// GetNowPage returns the current page number. Pages count from 1:
// GetBeginRow uses nowPage-1.
func (m *MultiFunction) GetNowPage() int {
	page := m.nowPage
	return page
}
// GetTotal returns the number of matching rows. The first call (while total
// is still -1) runs a count through Execute.Count, which stores the result
// in m.total. Later calls return that cached value.
func (m *MultiFunction) GetTotal() int {
	if m.total != -1 {
		return m.total
	}
	NewExecute().Count(m)
	return m.total
}
// GetTotalPage returns the number of pages needed for GetTotal rows at the
// current page size (rounded up), caching it in m.totalPage.
//
// A non-positive page size previously caused an integer divide-by-zero
// panic. It now counts every row as a single page: 1 when there are rows,
// 0 otherwise.
func (m *MultiFunction) GetTotalPage() int {
	m.total = m.GetTotal()
	switch {
	case m.pageSize <= 0:
		if m.total > 0 {
			m.totalPage = 1
		} else {
			m.totalPage = 0
		}
	case m.total%m.pageSize == 0:
		m.totalPage = m.total / m.pageSize
	default:
		m.totalPage = m.total/m.pageSize + 1
	}
	return m.totalPage
}
// GetBeginRow returns the limit offset. It uses the value given to
// SetBeginRow when that is non-negative; otherwise it derives the offset
// from the 1-based page number as (nowPage-1)*pageSize.
func (m *MultiFunction) GetBeginRow() int {
	if m.beginRow >= 0 {
		return m.beginRow
	}
	return (m.nowPage - 1) * m.pageSize
}
// SetBeginRow overrides the limit offset. A negative value makes
// GetBeginRow fall back to the page-derived offset.
func (m *MultiFunction) SetBeginRow(n int) {
	offset := n
	m.beginRow = offset
}
// Execute runs the statements described by a MultiFunction. It holds no
// state of its own.
type Execute struct{}

// NewExecute returns a ready-to-use *Execute.
func NewExecute() *Execute {
	ex := new(Execute)
	return ex
}
// Insert runs an INSERT inside a transaction.
//
// Unless a statement was supplied with SetSql, the INSERT is generated from
// mf.Model by GetInserSql. On success mf.Num holds the number of affected
// rows. Errors are recorded in the returned message via util.CheckErr.
func (ex *Execute) Insert(mf *MultiFunction) *util.Message {
	msg := util.NewMessage()
	db, err := GetDBConnect(mf)
	util.CheckErr(err, msg)
	if err != nil {
		return msg
	}
	// Defer only after a successful connect: db is nil on error and
	// (*sql.DB).Close would panic on a nil receiver.
	defer db.Close()
	tx, err := db.Begin()
	util.CheckErr(err, msg)
	if err != nil {
		return msg
	}
	// Release the transaction on every early return. After a successful
	// Commit this is a no-op (it returns sql.ErrTxDone, which is ignored).
	defer tx.Rollback()
	if !mf.mysql {
		mf.tag = 0
		GetInserSql(mf)
	}
	res, err := tx.Exec(mf.sql, mf.cv...)
	util.CheckErr(err, msg)
	if err != nil {
		return msg
	}
	if err := tx.Commit(); err != nil {
		util.CheckErr(err, msg)
		return msg
	}
	// Not every driver reports affected rows; keep 0 in that case.
	mf.Num, _ = res.RowsAffected()
	return msg
}
// Delete runs a DELETE inside a transaction.
//
// Unless a statement was supplied with SetSql, the DELETE is generated by
// GetDeleteSql from mf.Ids and mf.Rule. On success mf.Num holds the number
// of affected rows. Errors are recorded in the returned message via
// util.CheckErr.
func (ex *Execute) Delete(mf *MultiFunction) *util.Message {
	msg := util.NewMessage()
	db, err := GetDBConnect(mf)
	util.CheckErr(err, msg)
	if err != nil {
		return msg
	}
	// db is nil on error; only close a successfully opened handle.
	defer db.Close()
	if !mf.mysql {
		mf.tag = 0
		GetDeleteSql(mf)
	}
	tx, err := db.Begin()
	util.CheckErr(err, msg)
	if err != nil {
		return msg
	}
	// No-op after a successful Commit; otherwise releases the transaction.
	defer tx.Rollback()
	fmt.Println("finally sql---->" + mf.sql)
	res, err := tx.Exec(mf.sql, mf.cv...)
	util.CheckErr(err, msg)
	if err != nil {
		return msg
	}
	if err := tx.Commit(); err != nil {
		util.CheckErr(err, msg)
		return msg
	}
	mf.Num, _ = res.RowsAffected()
	return msg
}
// Update runs an UPDATE inside a transaction.
//
// Unless a statement was supplied with SetSql, the UPDATE is generated by
// GetUpdateSql from mf.KV, mf.Ids and mf.Rule. On success mf.Num holds the
// number of affected rows. Errors are recorded in the returned message via
// util.CheckErr.
func (ex *Execute) Update(mf *MultiFunction) *util.Message {
	msg := util.NewMessage()
	db, err := GetDBConnect(mf)
	util.CheckErr(err, msg)
	if err != nil {
		return msg
	}
	// db is nil on error; only close a successfully opened handle.
	defer db.Close()
	tx, err := db.Begin()
	util.CheckErr(err, msg)
	if err != nil {
		return msg
	}
	// No-op after a successful Commit; otherwise releases the transaction.
	defer tx.Rollback()
	if !mf.mysql {
		mf.tag = 0
		GetUpdateSql(mf)
	}
	res, err := tx.Exec(mf.sql, mf.cv...)
	util.CheckErr(err, msg)
	if err != nil {
		return msg
	}
	if err := tx.Commit(); err != nil {
		util.CheckErr(err, msg)
		return msg
	}
	mf.Num, _ = res.RowsAffected()
	return msg
}
// QueryAllOrByCondition runs a SELECT and stores the rows in mf.ResultSet.
//
// Unless a statement was supplied with SetSql, the SELECT is generated by
// GetQueryAllOrByConditionSql. Errors are recorded in the returned message
// via util.CheckErr.
func (ex *Execute) QueryAllOrByCondition(mf *MultiFunction) *util.Message {
	msg := util.NewMessage()
	db, err := GetDBConnect(mf)
	util.CheckErr(err, msg)
	if err != nil {
		return msg
	}
	// db is nil on error; only close a successfully opened handle.
	defer db.Close()
	tx, err := db.Begin()
	util.CheckErr(err, msg)
	if err != nil {
		return msg
	}
	// The transaction is never committed here. Rolling it back releases its
	// connection instead of leaving the transaction open.
	defer tx.Rollback()
	if !mf.mysql {
		mf.tag = 0
		GetQueryAllOrByConditionSql(mf)
	}
	rls, err := tx.Query(mf.sql, mf.cv...)
	util.CheckErr(err, msg)
	if err != nil {
		// rls is nil here; closing it would panic.
		return msg
	}
	defer rls.Close()
	handleRls(rls, mf, msg)
	return msg
}
// Count runs a row-count query and stores the first column of the first row
// in mf.total.
//
// Unless a statement was supplied with SetSql, the query is generated by
// GetCountSql. If the count cannot be read, mf.total keeps its previous
// value (-1 when never counted). Errors are recorded in the returned
// message via util.CheckErr.
func (ex *Execute) Count(mf *MultiFunction) *util.Message {
	msg := util.NewMessage()
	db, err := GetDBConnect(mf)
	util.CheckErr(err, msg)
	if err != nil {
		return msg
	}
	// db is nil on error; only close a successfully opened handle.
	defer db.Close()
	tx, err := db.Begin()
	util.CheckErr(err, msg)
	if err != nil {
		// tx is nil here; committing it would panic.
		return msg
	}
	defer tx.Commit()
	if !mf.mysql {
		mf.tag = 0
		GetCountSql(mf)
	}
	rls, err := tx.Query(mf.sql, mf.cv...)
	util.CheckErr(err, msg)
	if err != nil {
		return msg
	}
	defer rls.Close()
	if rls.Next() {
		if err := rls.Scan(&mf.total); err != nil {
			util.CheckErr(err, msg)
			return msg
		}
	}
	if err := rls.Err(); err != nil {
		util.CheckErr(err, msg)
	}
	return msg
}
// handleRls reads every row of rls into mf.ResultSet.
//
// ResultSet is keyed by 0-based row index, then by column name. Each value
// is the column's raw bytes converted to a string, so SQL NULL becomes "".
//
// Errors from Columns, Scan or the iteration itself are recorded in msg via
// util.CheckErr. In that case mf.ResultSet is left unchanged, so a partial
// result is never published.
func handleRls(rls *sql.Rows, mf *MultiFunction, msg *util.Message) {
	column, err := rls.Columns()
	if err != nil {
		util.CheckErr(err, msg)
		return
	}
	values := make([][]byte, len(column))
	scans := make([]interface{}, len(column))
	for i := range values {
		scans[i] = &values[i]
	}
	results := make(map[int]map[string]string)
	for i := 0; rls.Next(); i++ {
		if err := rls.Scan(scans...); err != nil {
			util.CheckErr(err, msg)
			return
		}
		row := make(map[string]string, len(column))
		for k, v := range values {
			row[column[k]] = string(v)
		}
		results[i] = row
	}
	// Next returns false both at the end and on failure; Err tells them apart.
	if err := rls.Err(); err != nil {
		util.CheckErr(err, msg)
		return
	}
	mf.ResultSet = results
}
// CRUD groups the statement-running operations provided by *Execute. Each
// takes the MultiFunction that describes the statement and returns a
// *util.Message.
type CRUD interface {
	Delete(*MultiFunction) *util.Message
	Insert(*MultiFunction) *util.Message
	QueryAllOrByCondition(*MultiFunction) *util.Message
	Update(*MultiFunction) *util.Message
}
// GetCountSql builds "select count(*) as total from <table> where ..." into
// mf.sql, with its bind values in mf.cv.
//
// The WHERE clause comes from mf.Ids and mf.Rule via appendQueryCondition.
// When the table name is empty, mf.sql and mf.cv are cleared, matching the
// other generators, instead of emitting SQL with no table.
func GetCountSql(mf *MultiFunction) {
	if !util.IsNotEmpty(mf.tableN) {
		fmt.Println("mf.tableN is empty...")
		mf.sql, mf.cv = "", nil
		return
	}
	mf.tag = 0
	// appendQueryCondition binds one value per id and at most two per rule
	// (range rules bind both bounds). The previous NumField-based size could
	// be too small.
	mf.cv = make([]interface{}, len(mf.Ids)+2*len(mf.Rule))
	stmt := "select count(*) as total from " + mf.tableN + appendQueryCondition(mf)
	mf.cv = mf.cv[:mf.tag]
	fmt.Println("select coun sql ---> ", stmt)
	fmt.Println("query values --------> ", mf.cv)
	mf.sql = stmt
}
// GetDeleteSql builds "delete from <table> where ..." into mf.sql, with its
// bind values in mf.cv. The WHERE clause comes from mf.Ids and mf.Rule via
// appendQueryCondition.
//
// With neither Ids nor Rule the condition is just "where 1=1", so every row
// is deleted; callers must set a condition.
//
// When the table name is empty, mf.sql and mf.cv are cleared so that a
// statement left over from an earlier call can never run.
func GetDeleteSql(mf *MultiFunction) {
	if !util.IsNotEmpty(mf.tableN) {
		fmt.Println("mf.tableN is empty...")
		mf.sql, mf.cv = "", nil
		return
	}
	mf.tag = 0
	// One value per id and at most two per (range) rule. The old size was
	// NumField, so deleting more ids than the model has fields panicked.
	mf.cv = make([]interface{}, len(mf.Ids)+2*len(mf.Rule))
	stmt := "delete from " + mf.tableN + appendQueryCondition(mf)
	mf.cv = mf.cv[:mf.tag]
	fmt.Println("delete sql ---> ", stmt)
	mf.sql = stmt
	fmt.Println("delete value ---> ", mf.cv)
}
// GetQueryAllOrByConditionSql builds the SELECT for a query into mf.sql,
// with its bind values in mf.cv.
//
// Columns come from mf.Field when it is non-empty, otherwise from every
// model field except tableN. Then, in order:
//   - the WHERE clause from mf.Ids/mf.Rule is added unless IsQueryAll is set;
//   - "order by OrderBy OrderByType" is added when IsSort is set;
//   - "limit ?, ?" (GetBeginRow, GetPageSize) is added when IsPagin is set.
//
// When the table name is empty it prints a message and clears mf.sql and
// mf.cv, like the other generators. It used to panic, which is
// inappropriate for library code.
func GetQueryAllOrByConditionSql(mf *MultiFunction) {
	if !util.IsNotEmpty(mf.tableN) {
		fmt.Println("mf.tableN is empty...")
		mf.sql, mf.cv = "", nil
		return
	}
	mf.tag = 0
	stmt := bytes.NewBufferString("select ")
	if mf.IsDistinct {
		stmt.WriteString(" distinct ")
	}
	var columns string
	if len(mf.Field) == 0 {
		columns = appendFieldSql(mf.Model)
	} else {
		var b bytes.Buffer
		for _, col := range mf.Field {
			b.WriteString(col)
			b.WriteString(",")
		}
		columns = b.String()
	}
	stmt.WriteString(strings.TrimSuffix(columns, ","))
	stmt.WriteString(" from ")
	stmt.WriteString(mf.tableN)
	if mf.IsQueryAll {
		// Only the two limit arguments can be bound.
		mf.cv = make([]interface{}, 2)
	} else {
		// One value per id, at most two per (range) rule, plus two for limit.
		mf.cv = make([]interface{}, len(mf.Ids)+2*len(mf.Rule)+2)
		stmt.WriteString(appendQueryCondition(mf))
	}
	if mf.IsSort {
		// NOTE(review): OrderBy/OrderByType are spliced in verbatim because
		// ORDER BY cannot be parameterized; never fill them from untrusted input.
		stmt.WriteString(" order by ")
		stmt.WriteString(mf.OrderBy)
		stmt.WriteString(" " + mf.OrderByType)
	}
	if mf.IsPagin {
		stmt.WriteString(" limit ? ,? ")
		mf.cv[mf.tag] = mf.GetBeginRow()
		mf.tag++
		mf.cv[mf.tag] = mf.GetPageSize()
		mf.tag++
	}
	fmt.Println("query sql ---> ", stmt.String())
	mf.cv = mf.cv[:mf.tag]
	fmt.Println("query values ---> ", mf.cv)
	mf.sql = stmt.String()
}
// GetInserSql builds a parameterized INSERT for *mf.Model into mf.sql, with
// its bind values in mf.cv.
//
// Every field except tableN is considered:
//   - string fields are included only when util.IsNotEmpty reports them
//     non-empty;
//   - int fields are always included, zero too;
//   - fields of any other type are skipped.
//
// When the table name is empty or no field qualifies, mf.sql and mf.cv are
// cleared. This replaces the old slice-bounds panic and means a stale
// statement can never run.
func GetInserSql(mf *MultiFunction) {
	if !util.IsNotEmpty(mf.tableN) {
		fmt.Println("mf.tableN is empty...")
		mf.sql, mf.cv = "", nil
		return
	}
	mf.tag = 0
	model := reflect.ValueOf(mf.Model).Elem()
	modelType := model.Type()
	n := model.NumField()
	mf.cv = make([]interface{}, n)
	var columns, holders bytes.Buffer
	for i := 0; i < n; i++ {
		name := modelType.Field(i).Name
		if name == "tableN" {
			continue
		}
		field := model.Field(i)
		switch {
		case util.IsStringT(field.Type()):
			s := field.String()
			if !util.IsNotEmpty(s) {
				continue
			}
			mf.cv[mf.tag] = s
		case util.IsIntT(field.Type()):
			// The previous `string(fv) != ""` test converted the int to a
			// one-rune string and so could never be false (go vet flags it);
			// int fields have therefore always been inserted.
			mf.cv[mf.tag] = field.Int()
		default:
			continue
		}
		mf.tag++
		columns.WriteString(name)
		columns.WriteString(",")
		holders.WriteString("?,")
	}
	if mf.tag == 0 {
		fmt.Println("no insertable field in mf.Model...")
		mf.sql, mf.cv = "", nil
		return
	}
	stmt := "insert into " + mf.tableN + "(" +
		strings.TrimSuffix(columns.String(), ",") + ")values(" +
		strings.TrimSuffix(holders.String(), ",") + ")"
	fmt.Println("insert sql ---> ", stmt)
	mf.cv = mf.cv[:mf.tag]
	fmt.Println("insert values ---> ", mf.cv)
	mf.sql = stmt
}
// GetUpdateSql builds "update <table> set a=?,b=? where ..." into mf.sql,
// with its bind values in mf.cv.
//
// The assigned columns are the model field names listed as values of mf.KV.
// Only string and int fields are assigned; other types are skipped. The
// WHERE clause comes from mf.Ids and mf.Rule via appendQueryCondition. With
// neither set it is just "where 1=1", so every row is updated.
//
// When the table name is empty or nothing can be assigned, mf.sql and
// mf.cv are cleared instead of producing broken SQL.
func GetUpdateSql(mf *MultiFunction) {
	if !util.IsNotEmpty(mf.tableN) {
		fmt.Println("mf.tableN is empty...")
		mf.sql, mf.cv = "", nil
		return
	}
	mf.tag = 0
	model := reflect.ValueOf(mf.Model).Elem()
	// One value per assigned column and per id, plus up to two per rule.
	// Range rules bind both bounds, which overflowed the old len(Rule) sizing.
	mf.cv = make([]interface{}, len(mf.KV)+len(mf.Ids)+2*len(mf.Rule))
	var sets bytes.Buffer
	for _, name := range mf.KV {
		f := model.FieldByName(name)
		var val interface{}
		switch {
		case util.IsStringT(f.Type()):
			val = f.String()
		case util.IsIntT(f.Type()):
			val = f.Int()
		default:
			continue
		}
		sets.WriteString(name)
		sets.WriteString("=?,")
		mf.cv[mf.tag] = val
		mf.tag++
	}
	if mf.tag == 0 {
		fmt.Println("no updatable field in mf.KV...")
		mf.sql, mf.cv = "", nil
		return
	}
	stmt := "update " + mf.tableN + " set " + strings.TrimSuffix(sets.String(), ",") + appendQueryCondition(mf)
	fmt.Println("update sql ---> ", stmt)
	mf.cv = mf.cv[:mf.tag]
	fmt.Println("update values ---> ", mf.cv)
	mf.sql = stmt
}
// appendFieldSql lists the field names of the struct that o points to,
// skipping the special "tableN" field. Each name is followed by a comma
// (for example "Id,Name,"). The result is empty when no field qualifies.
func appendFieldSql(o interface{}) string {
	var cols bytes.Buffer
	typ := reflect.ValueOf(o).Elem().Type()
	for i, n := 0, typ.NumField(); i < n; i++ {
		name := typ.Field(i).Name
		if name == "tableN" {
			continue
		}
		cols.WriteString(name + ",")
	}
	return cols.String()
}
// appendQueryCondition returns the WHERE clause for mf. It starts from
// " where 1=1 " and appends one "?" bind value to mf.cv (advancing mf.tag)
// per placeholder:
//   - mf.Ids become "and <mf.Id> in(?,...)".
//   - Each mf.Rule entry k -> op becomes " and k<op>", bound to the model
//     field k. String and int fields are supported; with LIKE the value is
//     wrapped in '%'.
//   - A rule whose op contains "@_data_@" is a range: it adds two
//     conditions, whose bounds are the field's string value split on the
//     same marker.
//
// mf.cv grows if the caller pre-sized it too small, so an undersized slice
// can no longer cause an index-out-of-range panic here.
func appendQueryCondition(mf *MultiFunction) string {
	addArg := func(v interface{}) {
		for int64(len(mf.cv)) <= mf.tag {
			mf.cv = append(mf.cv, nil)
		}
		mf.cv[mf.tag] = v
		mf.tag++
	}
	condition := bytes.NewBufferString(" where 1=1 ")
	if len(mf.Ids) > 0 {
		condition.WriteString("and ")
		condition.WriteString(mf.Id)
		condition.WriteString(" in(")
		for i, id := range mf.Ids {
			if i > 0 {
				condition.WriteString(",")
			}
			condition.WriteString("?")
			addArg(id)
		}
		condition.WriteString(")")
	}
	if len(mf.Rule) == 0 {
		return condition.String()
	}
	model := reflect.ValueOf(mf.Model).Elem()
	for k, v := range mf.Rule {
		if strings.Contains(v, "@_data_@") {
			// NOTE(review): panics if either the operator or the field value
			// lacks the "@_data_@" separator; callers must supply both bounds.
			ops := strings.Split(v, "@_data_@")
			bounds := strings.Split(model.FieldByName(k).String(), "@_data_@")
			condition.WriteString(" and ")
			condition.WriteString(k)
			condition.WriteString(ops[0])
			addArg(bounds[0])
			condition.WriteString(" and ")
			condition.WriteString(k)
			condition.WriteString(ops[1])
			addArg(bounds[1])
			continue
		}
		f := model.FieldByName(k)
		var val interface{}
		switch {
		case util.IsStringT(f.Type()):
			s := f.String()
			if v == LIKE {
				val = "%" + s + "%"
			} else {
				val = s
			}
		case util.IsIntT(f.Type()):
			n := f.Int()
			if v == LIKE {
				val = "%" + strconv.FormatInt(n, 10) + "%"
			} else {
				val = n
			}
		default:
			// NOTE(review): rules on fields of other types are ignored, which
			// widens the match (dangerous for DELETE/UPDATE).
			continue
		}
		condition.WriteString(" and ")
		condition.WriteString(k)
		condition.WriteString(v)
		addArg(val)
	}
	return condition.String()
}
// GetDBConnect opens a handle using mf.DbType (driver name) and mf.DbAddr
// (data source name), then verifies it with Ping.
//
// On failure it reports the error through util.HandleErr and returns a nil
// handle. A handle whose Ping failed is closed before returning, so nothing
// leaks. The caller owns and must Close the returned *sql.DB.
func GetDBConnect(mf *MultiFunction) (*sql.DB, error) {
	db, openConnectErr := sql.Open(mf.DbType, mf.DbAddr)
	if openConnectErr != nil {
		fmt.Println("open database connect failed")
		util.HandleErr(openConnectErr)
		return nil, openConnectErr
	}
	pingErr := db.Ping()
	if pingErr != nil {
		// Log only the part after the last '@'; the DSN's user:password
		// prefix must not be printed.
		target := mf.DbAddr
		if i := strings.LastIndex(target, "@"); i >= 0 {
			target = target[i+1:]
		}
		fmt.Printf("ping %s failed\n", target)
		util.HandleErr(pingErr)
		db.Close()
		return nil, pingErr
	}
	return db, nil
}
// GetDBConnectAddr reads "key=value" lines from filePath (via util.ReadLine)
// and builds a connection address of the form
//
//	db_userN:db_pwd@db_agt(db_url)/db_dbN
//
// It returns that address together with db_type.
//
// Keys are matched exactly, after trimming spaces around the key. Only the
// first '=' splits a line, so values such as passwords may contain '='.
// Lines without '=' are ignored. Values are used as written.
//
// Missing settings are only reported on stdout, as before. The logged
// address has its password masked. A ReadLine error is returned unchanged.
func GetDBConnectAddr(filePath string) (dbType string, addr string, e error) {
	lines, e := util.ReadLine(filePath, 100)
	if e != nil {
		return "", "", e
	}
	var dbUser, dbPwd, dbName, dbURL, dbAgt string
	hasPwd := false
	for _, line := range lines {
		kv := strings.SplitN(line, "=", 2)
		if len(kv) != 2 {
			continue
		}
		switch strings.TrimSpace(kv[0]) {
		case "db_type":
			dbType = kv[1]
		case "db_userN":
			dbUser = kv[1]
		case "db_pwd":
			hasPwd = true
			dbPwd = kv[1]
		case "db_dbN":
			dbName = kv[1]
		case "db_url":
			dbURL = kv[1]
		case "db_agt":
			dbAgt = kv[1]
		}
	}
	if !util.IsNotEmpty(dbType) {
		fmt.Println("not found db_type properties or db_type no empty...")
	}
	if !util.IsNotEmpty(dbName) {
		fmt.Println("not found db_dbN properties or db_dbN no empty...")
	}
	if !util.IsNotEmpty(dbUser) {
		fmt.Println("not found db_userN properties or db_userN no empty...")
	}
	if !util.IsNotEmpty(dbAgt) {
		fmt.Println("not found db_agt properties or db_agt no empty...")
	}
	if !util.IsNotEmpty(dbURL) {
		fmt.Println("not found db_url properties or db_url no empty...")
	}
	if !hasPwd {
		fmt.Println("not found db_pwd properties...")
	}
	addr = dbUser + ":" + dbPwd + "@" + dbAgt + "(" + dbURL + ")/" + dbName
	// The returned addr carries the real password; the log line does not.
	fmt.Println("get connect addr --->", dbUser+":***@"+dbAgt+"("+dbURL+")/"+dbName)
	return dbType, addr, nil
}
// Operator fragments used as values in MultiFunction.Rule. Each contains
// exactly one "?" placeholder, which appendQueryCondition binds to the model
// field's value. With LIKE the value is wrapped in '%' wildcards.
const (
	EQ   = " = ? "
	GT   = " > ? "
	GE   = " >= ? "
	LT   = " < ? "
	LE   = " <= ? "
	LIKE = " like ? "
)