Golang基于sql.DB完成PG原生SQL操作

1. 初始化配置

package settings

import (
	"flag"
	"github.com/go-ini/ini"
	"go.uber.org/zap"
	"log"
)

// Cfg holds the parsed conf/app.ini; Init assigns it.
var Cfg *ini.File

// PostgreSQL connection settings. LoadDB fills them from the [database]
// section of Cfg (keys HOST, PORT, USER, PASSWORD, DBNAME).
var (
	PgHost     string
	PgPort     int
	PgUser     string
	PgPassword string
	PgDatabase string
)

// Init registers and parses the -env flag (default "test"), loads
// conf/app.ini into Cfg, and then fills the database settings via LoadDB.
// It panics if the ini file cannot be loaded.
//
// NOTE(review): Init registers the "env" flag on the global flag set, so
// calling it twice panics ("flag redefined"). The env value is currently only
// logged. zap.L() is a no-op logger until zap.ReplaceGlobals is called, so the
// error log below may be silent if Init runs before logger setup.
func Init() {
	var env string
	flag.StringVar(&env, "env", "test", "example':test'")
	flag.Parse()

	var err error
	Cfg, err = ini.Load("conf/app.ini")
	if err != nil {
		// Failing to load the config is fatal, so log it at error level
		// (not info) and keep the cause in the entry.
		zap.L().Error("Fail to parse 'conf/app.ini'", zap.Error(err))
		panic(err)
	}
	log.Print(env)
	LoadDB()
}

// LoadDB copies the PostgreSQL connection settings from the [database]
// section of Cfg into the package-level Pg* variables. It uses go-ini's
// Must* helpers with the defaults shown below.
// It terminates the process via log.Fatalf if the section cannot be found.
func LoadDB() {
	section, err := Cfg.GetSection("database")
	if err != nil {
		log.Fatalf("Fail to get section 'database': %v", err)
	}
	key := section.Key // method value; shortens the lookups below

	PgHost = key("HOST").MustString("127.0.0.1")
	PgPort = key("PORT").MustInt(5432)
	PgUser = key("USER").MustString("postgres")
	PgPassword = key("PASSWORD").MustString("postgres")
	PgDatabase = key("DBNAME").MustString("dw")
}

import (
	"database/sql"
	"errors"
	"fmt"
	_ "github.com/lib/pq"
	"go.uber.org/zap"
	"log"
	"pkg/settings"
	//"github.com/go-pg/pg/v10"
)

// PgDB keeps a bounded set of idle *sql.DB handles in a buffered channel.
// Handles are taken with getPgConn and handed back with putPgConn. The
// channel's capacity caps how many idle handles are retained.
type PgDB struct {
	pool chan *sql.DB
}

var (
	// ErrPoolClosed is returned by getPgConn once Close has closed the pool.
	ErrPoolClosed = errors.New("连接池已经关闭!")

	// pgdb is the package-wide pool; Init sizes it.
	pgdb PgDB
)

// Init prepares the package-wide pool so it can retain up to 1000 idle
// *sql.DB handles. The handles themselves are opened lazily by getPgConn.
func Init() {
	const poolCapacity = 1000
	pgdb.New(poolCapacity)
}

// New (re)initialises the pool with room for size idle handles.
// Any previously held channel is simply replaced; it is not closed.
func (pgdb *PgDB) New(size int) {
	idle := make(chan *sql.DB, size)
	pgdb.pool = idle
}

// Close shuts the pool. It first closes the channel, so later getPgConn calls
// report ErrPoolClosed. It then closes every idle handle still buffered.
// Errors from individual handles are logged, not returned.
//
// NOTE(review): a putPgConn that runs concurrently with or after Close will
// panic (send on a closed channel); callers must stop using the pool first.
func (pgdb *PgDB) Close() {
	close(pgdb.pool)
	for {
		db, ok := <-pgdb.pool
		if !ok {
			break
		}
		if err := db.Close(); err != nil {
			zap.L().Error(err.Error())
		}
	}
	log.Println("Close:", "资源回收成功!")
}

// createPgConn opens a new *sql.DB for the PostgreSQL server described by the
// settings package and verifies it with Ping. If either step fails, it logs
// the failure and returns nil; callers must treat nil as an error.
//
// NOTE(review): values are interpolated unquoted into the key=value DSN, so a
// password containing spaces or quotes will break the connection string.
func (pgdb *PgDB) createPgConn() *sql.DB {
	pgStr := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
		settings.PgHost, settings.PgPort, settings.PgUser, settings.PgPassword, settings.PgDatabase)
	db, err := sql.Open("postgres", pgStr)
	if err != nil {
		log.Println("DB Open:", err)
		fmt.Println("============================", err.Error())
		return nil
	}
	if err = db.Ping(); err != nil {
		log.Println("DB PING: ", err)
		fmt.Println("============================", err.Error())
		// sql.Open succeeded, so this handle owns resources. Release them
		// rather than leaking a handle nobody will reference. Best effort:
		// the Ping error above is the one worth reporting.
		_ = db.Close()
		return nil
	}
	return db
}

// getPgConn hands out a *sql.DB. It takes an idle one from the pool if one is
// available and otherwise opens a fresh one with createPgConn.
// It returns ErrPoolClosed once Close has closed the pool. When a new handle
// cannot be opened it returns a non-nil error rather than a nil handle with a
// nil error, so callers never dereference a nil *sql.DB.
func (pgdb *PgDB) getPgConn() (*sql.DB, error) {
	select {
	case db, ok := <-pgdb.pool:
		if !ok {
			return nil, ErrPoolClosed
		}
		if db != nil {
			return db, nil
		}
		// A nil handle should never have been pooled; discard it and open a
		// fresh one below.
	default:
		// The pool is empty: open a new handle below.
	}

	db := pgdb.createPgConn()
	if db == nil {
		// createPgConn has already logged the underlying cause.
		return nil, errors.New("pgdb: could not open a PostgreSQL connection")
	}
	return db, nil
}

// putPgConn returns db to the pool. If the pool is already full, it closes db
// instead and logs any Close error. A nil db is ignored.
func (pgdb *PgDB) putPgConn(db *sql.DB) {
	if db == nil {
		// Nothing to recycle (e.g. the handle could not be opened). Pooling nil
		// would hand a nil handle to a later caller, and (*sql.DB)(nil).Close
		// panics.
		return
	}
	select {
	case pgdb.pool <- db:
	default:
		// The pool is full: drop this handle.
		if err := db.Close(); err != nil {
			zap.L().Error(err.Error())
		}
	}
}

// ReadOne runs readSql and scans the first result row into dest as strings.
// dest needs one element per selected column, and it is filled in place.
// It returns dest on success and (nil, nil) when the query matches no rows.
// On any other failure it logs the error and returns (nil, err).
//
// NOTE(review): a NULL column cannot be scanned into a string, so such rows
// surface as an error; confirm callers' queries never return NULL.
func (pgdb *PgDB) ReadOne(readSql string, dest []string) ([]string, error) {
	db, err := pgdb.getPgConn()
	if err != nil {
		zap.L().Error(err.Error())
		return nil, err
	}
	defer pgdb.putPgConn(db)

	// Scan needs pointers; point each target at the matching dest slot.
	targets := make([]interface{}, 0, len(dest))
	for i := range dest {
		targets = append(targets, &dest[i])
	}

	switch err = db.QueryRow(readSql).Scan(targets...); {
	case err == sql.ErrNoRows:
		return nil, nil
	case err != nil:
		zap.L().Error(err.Error())
		return nil, err
	default:
		return dest, nil
	}
}

// ReadMany runs readSql and returns one map per result row. Each map is keyed
// by column name and holds every value scanned as a string.
// dest is scratch space: it needs one element per selected column and is
// overwritten row by row.
// A query that returns no rows yields a nil slice and a nil error.
// Any error from acquiring a handle, running the query, reading column
// metadata, scanning, or iterating is logged and returned.
func (pgdb *PgDB) ReadMany(readSql string, dest []string) ([]map[string]interface{}, error) {
	db, err := pgdb.getPgConn()
	if err != nil {
		zap.L().Error(err.Error())
		return nil, err
	}
	// Hand the handle back only when we are done with it. putPgConn closes it
	// when the pool is full, which would break the query below.
	defer pgdb.putPgConn(db)

	rows, err := db.Query(readSql)
	if err != nil {
		zap.L().Error(err.Error())
		return nil, err
	}
	// Without Close the underlying connection is never released.
	defer rows.Close()

	columns, _, err := GetQueryColumns(rows)
	if err != nil {
		zap.L().Error(err.Error())
		return nil, err
	}

	// The scan targets point into dest and never change, so build them once.
	receiver := make([]interface{}, len(dest))
	for i := range receiver {
		receiver[i] = &dest[i]
	}

	var result []map[string]interface{}
	for rows.Next() {
		if err = rows.Scan(receiver...); err != nil {
			zap.L().Error(err.Error())
			return nil, err
		}
		// A successful Scan guarantees len(dest) == len(columns).
		item := make(map[string]interface{}, len(columns))
		for i, name := range columns {
			item[name] = dest[i]
		}
		result = append(result, item)
	}
	// rows.Next returns false on errors too; distinguish that from the end.
	if err = rows.Err(); err != nil {
		zap.L().Error(err.Error())
		return nil, err
	}
	return result, nil
}

// GetQueryColumns reports the result-set layout of rows. It returns the column
// names in select-list order and a map from column name to the driver's
// database type name. If two columns share a name, the map keeps the type of
// the later one. An error from rows.ColumnTypes is returned unchanged.
func GetQueryColumns(rows *sql.Rows) ([]string, map[string]string, error) {
	types, err := rows.ColumnTypes()
	if err != nil {
		return nil, nil, err
	}

	names := make([]string, 0, len(types))
	typeByName := make(map[string]string, len(types))
	for _, ct := range types {
		name := ct.Name()
		names = append(names, name)
		typeByName[name] = ct.DatabaseTypeName()
	}
	return names, typeByName, nil
}

// Insert executes insertSql.
// When pkey is non-empty, " RETURNING <pkey>" is appended, and the returned
// value is scanned into an int and returned. pkey must therefore name a single
// integer-valued column.
// When pkey is empty, the statement runs via Exec and 0 is returned on success.
func (pgdb *PgDB) Insert(insertSql string, pkey string) (int, error) {
	db, err := pgdb.getPgConn()
	if err != nil {
		log.Println(err)
		return 0, err
	}
	defer pgdb.putPgConn(db)

	if pkey == "" {
		// Without RETURNING an INSERT yields no row, so QueryRow().Scan would
		// always report sql.ErrNoRows even though the insert succeeded.
		_, err = db.Exec(insertSql)
		return 0, err
	}

	lastInsertId := 0
	err = db.QueryRow(fmt.Sprintf("%s RETURNING %s", insertSql, pkey)).Scan(&lastInsertId)
	return lastInsertId, err
}

// Exec runs sql (a statement that returns no rows) and reports how many rows
// it affected. Errors from acquiring a handle, executing the statement, or
// reading the affected-row count are returned with a 0 count.
// Note: the parameter name sql shadows the database/sql package in this body.
func (pgdb *PgDB) Exec(sql string) (int, error) {
	db, err := pgdb.getPgConn()
	if err != nil {
		log.Println(err)
		return 0, err
	}
	defer pgdb.putPgConn(db)

	result, err := db.Exec(sql)
	if err != nil {
		return 0, err
	}
	affected, err := result.RowsAffected()
	if err != nil {
		return 0, err
	}
	return int(affected), nil
}

2. 用 utils 包装一下：将 map 转成 SQL 语句

package models


func CreateOrUpdateRow(table string, pkey string, columns []string, data map[string]interface{}) (int, error) {
	sql := BuildMap2Sql(table, pkey, columns, data)
	//fmt.Println(sql)
	affectId, err := pgdb.Exec(sql)
	if err != nil {
		return 0, err
	}
	return affectId, nil
}

func Insert(table string, pkey string, columns []string, data map[string]interface{}) (int, error) {
	sql := utils.BuildMap2Sql(table, "", columns, data)
	return pgdb.Insert(sql, pkey)
}

// ReadOne runs selectSql on the shared pool and scans the first result row
// into dest (see PgDB.ReadOne).
func ReadOne(selectSql string, dest []string) ([]string, error) {
	row, err := pgdb.ReadOne(selectSql, dest)
	return row, err
}

// ReadMany runs selectSql on the shared pool and returns every row as a
// column-name-keyed map (see PgDB.ReadMany).
func ReadMany(selectSql string, dest []string) ([]map[string]interface{}, error) {
	rows, err := pgdb.ReadMany(selectSql, dest)
	return rows, err
}

// Exec runs sql on the shared pool and returns the number of affected rows
// (see PgDB.Exec).
func Exec(sql string) (int, error) {
	affected, err := pgdb.Exec(sql)
	return affected, err
}

// BuildMap2Sql builds an INSERT statement for table from data, taking the
// columns in the order given by columns. When pkey is non-empty, the statement
// becomes an INSERT ... ON CONFLICT(pkey) upsert.
//
// A column is included only if data has a key for it and Interface2Str of its
// value is non-empty.
//
// pkey may list several comma-separated columns (e.g.
// "userid,leave_code,quota_cycle"). Each of those columns, plus create_time,
// is left out of the DO UPDATE SET list. If nothing is left to update, the
// statement uses DO NOTHING, because an empty SET list is invalid SQL.
//
// Values are embedded as single-quoted literals with any embedded ' doubled.
// NOTE(review): this relies on standard_conforming_strings=on (the PostgreSQL
// default) and does not quote table/column identifiers. Prefer parameterized
// queries for untrusted input.
func BuildMap2Sql(table string, pkey string, columns []string, data map[string]interface{}) string {
	// Columns whose stored value must not be overwritten on conflict.
	skipUpdate := map[string]bool{"create_time": true}
	for _, k := range strings.Split(pkey, ",") {
		if k = strings.TrimSpace(k); k != "" {
			skipUpdate[k] = true
		}
	}

	newColumns := make([]string, 0, len(columns))
	values := make([]string, 0, len(columns))
	updates := make([]string, 0, len(columns))
	for _, column := range columns {
		v, ok := data[column]
		if !ok {
			continue
		}
		vv := Interface2Str(v)
		if len(vv) < 1 {
			continue
		}
		// Double embedded quotes so the value cannot terminate the literal.
		literal := "'" + strings.ReplaceAll(vv, "'", "''") + "'"
		newColumns = append(newColumns, column)
		values = append(values, literal)
		if !skipUpdate[column] {
			updates = append(updates, column+"="+literal)
		}
	}

	insertSql := fmt.Sprintf("INSERT INTO %s (%s) VALUES(%s)", table,
		strings.Join(newColumns, ","), strings.Join(values, ","))
	if pkey == "" {
		return insertSql
	}
	if len(updates) == 0 {
		return fmt.Sprintf("%s ON CONFLICT(%s) DO NOTHING", insertSql, pkey)
	}
	return fmt.Sprintf("%s ON CONFLICT(%s) DO UPDATE set %s", insertSql, pkey, strings.Join(updates, ","))
}

3. 调用

// Upsert one row into dw.quota_auto. The conflict target is a composite key
// passed as a comma-separated column list; quotaColums lists which keys of
// data become columns.
affectId, err := models.CreateOrUpdateRow("dw.quota_auto", "userid,leave_code,quota_cycle", quotaColums, data)
fmt.Println(i, affectId, err)

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值