1. 初始化配置
package settings
import (
"flag"
"github.com/go-ini/ini"
"go.uber.org/zap"
"log"
)
// Cfg holds the parsed conf/app.ini; it is assigned by Init.
var Cfg *ini.File

// PostgreSQL connection settings, filled in by LoadDB from the [database]
// section of Cfg.
var (
	PgHost     string
	PgPort     int
	PgUser     string
	PgPassword string
	PgDatabase string
)
// Init registers and parses the -env command-line flag (default "test"),
// loads conf/app.ini into Cfg and then reads the database settings via LoadDB.
// It panics with the load error if conf/app.ini cannot be loaded.
//
// NOTE(review): the -env value is only logged; it does not select a config
// file. Calling Init a second time panics because -env would be registered
// on the default FlagSet again.
func Init() {
	var env string
	flag.StringVar(&env, "env", "test", "example':test'")
	flag.Parse()

	var err error
	Cfg, err = ini.Load("conf/app.ini")
	if err != nil {
		// This is a fatal path: log at error level and keep the cause.
		zap.L().Error("Fail to parse 'conf/app.ini'", zap.Error(err))
		panic(err)
	}
	log.Print(env)
	LoadDB()
}
// LoadDB copies the [database] section of Cfg into the Pg* variables.
// Cfg must already be loaded (see Init). If Cfg.GetSection reports an error,
// for example because the section is missing, the process exits via
// log.Fatalf. Each key falls back to the default given below when go-ini's
// Must* helper cannot produce a value (empty string / unparsable int).
func LoadDB() {
	section, err := Cfg.GetSection("database")
	if err != nil {
		log.Fatalf("Fail to get section 'database': %v", err)
	}
	key := section.Key

	PgHost = key("HOST").MustString("127.0.0.1")
	PgPort = key("PORT").MustInt(5432)
	PgUser = key("USER").MustString("postgres")
	PgPassword = key("PASSWORD").MustString("postgres")
	PgDatabase = key("DBNAME").MustString("dw")
}
import (
"database/sql"
"errors"
"fmt"
_ "github.com/lib/pq"
"go.uber.org/zap"
"log"
"pkg/settings"
//"github.com/go-pg/pg/v10"
)
// PgDB is a bounded free-list of *sql.DB handles. getPgConn takes an idle
// handle (opening a new one when none is idle) and putPgConn hands it back
// (closing it when the list is already full).
type PgDB struct {
	// pool buffers idle handles; its capacity is the size passed to New.
	pool chan *sql.DB
}

var (
	// ErrPoolClosed is returned by getPgConn once Close has closed the pool.
	ErrPoolClosed = errors.New("连接池已经关闭!")

	// pgdb is the package-wide pool used by the package-level helpers.
	pgdb PgDB
)

// Init prepares the package-wide pool with room for 1000 idle handles.
func Init() {
	const idleCapacity = 1000
	pgdb.New(idleCapacity)
}

// New (re)creates the pool's buffer with capacity size.
// NOTE(review): calling New again drops the previous buffer without closing
// the handles it still holds; make panics if size is negative.
func (pgdb *PgDB) New(size int) {
	idle := make(chan *sql.DB, size)
	pgdb.pool = idle
}
// Close closes the pool's channel and then drains it, closing every idle
// handle that was still buffered. Close errors are logged, not returned.
//
// NOTE(review): a second Close panics (closing a closed channel), and any
// putPgConn after Close panics (send on a closed channel).
func (pgdb *PgDB) Close() {
	close(pgdb.pool)
	for {
		idle, ok := <-pgdb.pool
		if !ok {
			break
		}
		if err := idle.Close(); err != nil {
			zap.L().Error(err.Error())
		}
	}
	log.Println("Close:", "资源回收成功!")
}
// createPgConn opens a new *sql.DB for the PostgreSQL server described by the
// settings package and verifies it with Ping. On failure it logs the cause
// and returns nil; a handle whose Ping failed is closed first so it does not
// leak.
//
// NOTE(review): values are interpolated unquoted into the key=value DSN, so a
// password (or other value) containing spaces or quotes will not parse.
func (pgdb *PgDB) createPgConn() *sql.DB {
	dsn := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
		settings.PgHost, settings.PgPort, settings.PgUser, settings.PgPassword, settings.PgDatabase)

	db, err := sql.Open("postgres", dsn)
	if err != nil {
		log.Println("DB Open:", err)
		return nil
	}
	if err := db.Ping(); err != nil {
		log.Println("DB PING: ", err)
		// The handle is discarded, so release the resources sql.Open allocated.
		if cerr := db.Close(); cerr != nil {
			log.Println("DB Close:", cerr)
		}
		return nil
	}
	return db
}
// getPgConn takes an idle handle from the pool without blocking, or opens a
// new one via createPgConn when none is idle. It returns ErrPoolClosed once
// the pool has been closed, and an error when a new handle cannot be opened.
// It never returns a nil handle together with a nil error.
func (pgdb *PgDB) getPgConn() (*sql.DB, error) {
	select {
	case db, ok := <-pgdb.pool:
		if !ok {
			return nil, ErrPoolClosed
		}
		if db != nil {
			return db, nil
		}
		// A nil handle in the pool is unusable; discard it and open a new one.
	default:
	}

	db := pgdb.createPgConn()
	if db == nil {
		// createPgConn has already logged the cause.
		return nil, errors.New("pgdb: cannot open postgres connection")
	}
	return db, nil
}
// putPgConn hands db back to the pool without blocking. If the pool is full,
// the handle is closed instead and any Close error is logged. A nil handle is
// ignored, because it can be neither reused nor closed.
//
// NOTE(review): panics if Close has already closed the pool channel.
func (pgdb *PgDB) putPgConn(db *sql.DB) {
	if db == nil {
		return
	}
	select {
	case pgdb.pool <- db:
	default:
		if err := db.Close(); err != nil {
			zap.L().Error(err.Error())
		}
	}
}
// ReadOne runs readSql, reads at most one row and scans its columns into dest
// (one string per result column, in order). dest is filled in place and also
// returned. If the query yields no row, it returns (nil, nil).
// Errors from acquiring a handle, running the query or scanning are logged
// and returned.
//
// NOTE(review): scanning a NULL column into a string fails, so nullable
// columns need COALESCE in readSql. readSql runs verbatim — never build it
// from untrusted input.
func (pgdb *PgDB) ReadOne(readSql string, dest []string) ([]string, error) {
	db, err := pgdb.getPgConn()
	if err != nil {
		zap.L().Error(err.Error())
		return nil, err
	}
	defer pgdb.putPgConn(db)

	targets := make([]interface{}, len(dest))
	for i := range dest {
		targets[i] = &dest[i]
	}

	err = db.QueryRow(readSql).Scan(targets...)
	switch {
	case errors.Is(err, sql.ErrNoRows):
		return nil, nil
	case err != nil:
		zap.L().Error(err.Error())
		return nil, err
	}
	return dest, nil
}
// ReadMany runs readSql and returns one map per result row, keyed by column
// name, each value being the column scanned as a string. dest is scratch
// storage for scanning: it must have exactly one element per result column,
// and it is overwritten row by row. A query with no rows yields (nil, nil).
// Errors from acquiring a handle, running the query, reading column metadata,
// scanning or iterating are logged and returned.
func (pgdb *PgDB) ReadMany(readSql string, dest []string) ([]map[string]interface{}, error) {
	db, err := pgdb.getPgConn()
	if err != nil {
		zap.L().Error(err.Error())
		return nil, err
	}
	// Give the handle back only after the rows are consumed. Returning it
	// first lets putPgConn close it mid-query when the pool is full.
	defer pgdb.putPgConn(db)

	rows, err := db.Query(readSql)
	if err != nil {
		zap.L().Error(err.Error())
		return nil, err
	}
	// Open rows hold a driver connection until closed.
	defer rows.Close()

	columns, _, err := GetQueryColumns(rows)
	if err != nil {
		zap.L().Error(err.Error())
		return nil, err
	}

	// The scan targets always point at dest, so build them once.
	receiver := make([]interface{}, len(dest))
	for i := range dest {
		receiver[i] = &dest[i]
	}

	var result []map[string]interface{}
	for rows.Next() {
		if err := rows.Scan(receiver...); err != nil {
			zap.L().Error(err.Error())
			return nil, err
		}
		item := make(map[string]interface{}, len(columns))
		for i, column := range columns {
			item[column] = dest[i]
		}
		result = append(result, item)
	}
	if err := rows.Err(); err != nil {
		zap.L().Error(err.Error())
		return nil, err
	}
	return result, nil
}
// GetQueryColumns describes the result columns of rows. It returns their
// names in result order, plus a map from each name to the driver-reported
// database type name. When two columns share a name, the map keeps the later
// column's type. If rows has no column information (e.g. it is closed), it
// returns the error from rows.ColumnTypes.
func GetQueryColumns(rows *sql.Rows) ([]string, map[string]string, error) {
	types, err := rows.ColumnTypes()
	if err != nil {
		return nil, nil, err
	}

	names := make([]string, 0, len(types))
	typeByName := make(map[string]string, len(types))
	for _, col := range types {
		name := col.Name()
		names = append(names, name)
		typeByName[name] = col.DatabaseTypeName()
	}
	return names, typeByName, nil
}
// Insert executes insertSql and returns the new row's key. If pkey is
// non-empty, " RETURNING <pkey>" is appended and the returned value is
// scanned into an int. If pkey is empty, the statement is executed plainly
// and 0 is returned. A connection failure is logged and returned.
//
// NOTE(review): insertSql and pkey are interpolated verbatim — never build
// them from untrusted input.
func (pgdb *PgDB) Insert(insertSql string, pkey string) (int, error) {
	db, err := pgdb.getPgConn()
	if err != nil {
		log.Println(err)
		return 0, err
	}
	defer pgdb.putPgConn(db)

	if pkey == "" {
		// Without RETURNING the statement yields no row, and QueryRow().Scan
		// would report sql.ErrNoRows even though the insert succeeded.
		_, err = db.Exec(insertSql)
		return 0, err
	}

	lastInsertID := 0
	err = db.QueryRow(fmt.Sprintf("%s RETURNING %s", insertSql, pkey)).Scan(&lastInsertID)
	return lastInsertID, err
}
// Exec runs sql (a complete statement, executed verbatim) and returns the
// number of rows the driver reports as affected. A connection failure is
// logged and returned, as is any execution error.
func (pgdb *PgDB) Exec(sql string) (int, error) {
	db, err := pgdb.getPgConn()
	if err != nil {
		log.Println(err)
		return 0, err
	}
	defer pgdb.putPgConn(db)

	result, err := db.Exec(sql)
	if err != nil {
		return 0, err
	}
	// Best effort: some statements/drivers report no row count; treat that
	// as 0 rather than failing an Exec that already succeeded.
	affected, _ := result.RowsAffected()
	return int(affected), nil
}
2. 用 utils 包装一下，将 map 转成 SQL 语句
package models
func CreateOrUpdateRow(table string, pkey string, columns []string, data map[string]interface{}) (int, error) {
sql := BuildMap2Sql(table, pkey, columns, data)
//fmt.Println(sql)
affectId, err := pgdb.Exec(sql)
if err != nil {
return 0, err
}
return affectId, nil
}
// Insert renders a statement for table from the entries of data named in
// columns via utils.BuildMap2Sql with an empty conflict key (presumably a
// plain INSERT — confirm against utils), then runs it through the package
// pool's PgDB.Insert. When pkey is non-empty, that method appends
// "RETURNING pkey" and the returned value is passed back.
// NOTE(review): CreateOrUpdateRow calls the unqualified BuildMap2Sql while
// this calls utils.BuildMap2Sql — confirm they are the same helper.
func Insert(table string, pkey string, columns []string, data map[string]interface{}) (int, error) {
	const noConflictKey = ""
	stmt := utils.BuildMap2Sql(table, noConflictKey, columns, data)
	id, err := pgdb.Insert(stmt, pkey)
	return id, err
}
// ReadOne runs selectSql on the package-wide pool and scans at most one row
// into dest; it forwards PgDB.ReadOne's result unchanged.
func ReadOne(selectSql string, dest []string) ([]string, error) {
	row, err := pgdb.ReadOne(selectSql, dest)
	return row, err
}
// ReadMany runs selectSql on the package-wide pool and returns one
// column-name-keyed map per row; it forwards PgDB.ReadMany's result unchanged.
func ReadMany(selectSql string, dest []string) ([]map[string]interface{}, error) {
	rows, err := pgdb.ReadMany(selectSql, dest)
	return rows, err
}
// Exec runs sql on the package-wide pool and returns the affected-row count;
// it forwards PgDB.Exec's result unchanged.
func Exec(sql string) (int, error) {
	affected, err := pgdb.Exec(sql)
	return affected, err
}
// BuildMap2Sql renders an INSERT statement for table from the entries of data
// whose keys are listed in columns, in columns order. Entries missing from
// data, or whose Interface2Str rendering is empty, are skipped. Every value is
// emitted as a single-quoted SQL string literal with embedded single quotes
// doubled, so a value cannot terminate the literal.
//
// When pkey is non-empty, the statement becomes an upsert:
// "ON CONFLICT(pkey) DO UPDATE set col='v',...". The SET list covers the
// inserted columns except the conflict-key columns (pkey may list several,
// comma-separated) and create_time. If no such column remains,
// "DO NOTHING" is emitted instead, because an empty SET list is invalid SQL.
//
// NOTE(review): table, pkey and column names are interpolated verbatim and
// must come from trusted code. Quote doubling assumes the server runs with
// standard_conforming_strings=on (the PostgreSQL default) — confirm. If no
// column survives filtering, the "()"/"VALUES()" lists are still invalid SQL.
func BuildMap2Sql(table string, pkey string, columns []string, data map[string]interface{}) string {
	// Columns that the DO UPDATE clause must leave untouched.
	keepOnConflict := map[string]bool{"create_time": true}
	for _, key := range strings.Split(pkey, ",") {
		if key = strings.TrimSpace(key); key != "" {
			keepOnConflict[key] = true
		}
	}

	usedColumns := make([]string, 0, len(columns))
	values := make([]string, 0, len(columns))
	assignments := make([]string, 0, len(columns))
	for _, column := range columns {
		v, ok := data[column]
		if !ok {
			continue
		}
		text := Interface2Str(v)
		if text == "" {
			continue
		}
		literal := "'" + strings.ReplaceAll(text, "'", "''") + "'"
		usedColumns = append(usedColumns, column)
		values = append(values, literal)
		if !keepOnConflict[column] {
			assignments = append(assignments, column+"="+literal)
		}
	}

	insert := fmt.Sprintf("INSERT INTO %s (%s) VALUES(%s)", table,
		strings.Join(usedColumns, ","), strings.Join(values, ","))
	switch {
	case pkey == "":
		return insert
	case len(assignments) == 0:
		return fmt.Sprintf("%s ON CONFLICT(%s) DO NOTHING", insert, pkey)
	default:
		return fmt.Sprintf("%s ON CONFLICT(%s) DO UPDATE set %s", insert, pkey, strings.Join(assignments, ","))
	}
}
3. 调用
// Upsert one row into dw.quota_auto, using the composite conflict key
// (userid, leave_code, quota_cycle). quotaColums lists the columns to take
// from data. The result is the affected-row count reported by Exec.
// NOTE(review): i, quotaColums and data come from the enclosing scope, and i
// is presumably a loop index — confirm.
affectId, err := models.CreateOrUpdateRow("dw.quota_auto", "userid,leave_code,quota_cycle", quotaColums, data)
fmt.Println(i, affectId, err)