gorm mysql批量插入数据

2 篇文章 0 订阅

记录下项目中自己写的批量插入数据代码

/**
 * 获取批量添加数据sql语句
 */
func getBranchInsertSql(data interface{}, tableName string) (string, error) {
	var isArr bool
	dataValue := reflect.ValueOf(data)
	switch dataValue.Kind() {
	case reflect.Array, reflect.Slice:
		// 数组
		isArr = true
	case reflect.Struct:
		// 不是数组,单个元素,结构体
		isArr = false
	default:
		// 既不是结构体也不是数组,报错
		return "", errors.New("data type err, must be array or struct")
	}

	var arr []interface{}
	if isArr {
		for i := 0; i < dataValue.Len(); i++ {
			tmp := dataValue.Index(i)
			arr = append(arr, tmp.Interface())
		}
	} else {
		arr = append(arr, data)
	}

	return getBranchInsertSqlByArray(arr, tableName)
}

func getBranchInsertSqlByArray(objArr []interface{}, tableName string) (string, error) {
	// 转为数组
	if len(objArr) == 0 {
		return "", nil
	}
	fieldName := ""
	var valueTypeList []reflect.Kind
	fieldNum := reflect.TypeOf(objArr[0]).NumField()
	fieldT := reflect.TypeOf(objArr[0])
	for a := 0; a < fieldNum; a++ {
		name := getColumnName(fieldT.Field(a).Tag.Get("gorm"))
		// 添加字段名
		if a == fieldNum-1 {
			fieldName += fmt.Sprintf("`%s`", name)
		} else {
			fieldName += fmt.Sprintf("`%s`,", name)
		}

		valueTypeList = append(valueTypeList, fieldT.Field(a).Type.Kind())
	}
	var valueList []string
	for _, obj := range objArr {
		objV := reflect.ValueOf(obj)
		v := "("
		for index, t := range valueTypeList {
			var value string
			var err error
			if index == fieldNum - 1 {
				value, err = getFormatField(objV, index, t, "")
			} else {
				value, err = getFormatField(objV, index, t, ",")
			}
			if err != nil {
				return "", err
			}
			v += value
		}
		v += ")"
		valueList = append(valueList, v)
	}
	insertSql := fmt.Sprintf("insert into `%s` (%s) values %s", tableName, fieldName, strings.Join(valueList, ",") + ";")
	return insertSql, nil
}

// getFormatField 获取字段类型值转为字符串
func getFormatField(objV reflect.Value, index int, t reflect.Kind, sep string) (string, error) {
	v := ""

	switch t {
	case reflect.String:
		v += fmt.Sprintf("'%s'%s", objV.Field(index).String(), sep)
	case reflect.Bool:
		v += fmt.Sprintf("%t%s", objV.Field(index).Bool(), sep)
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		v += fmt.Sprintf("%d%s", objV.Field(index).Int(), sep)
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		v += fmt.Sprintf("%d%s", objV.Field(index).Uint(), sep)
	case reflect.Float32, reflect.Float64:
		v += fmt.Sprintf("%f%s", objV.Field(index).Float(), sep)
	case reflect.Interface, reflect.Ptr, reflect.Uintptr:
		return "", errors.New(fmt.Sprintf("batch insert unsupport type %s", t.String()))
	}

	return v, nil
}

// GetColumnName 获取字段名
func getColumnName(jsonName string) string {
	for _, name := range strings.Split(jsonName, ";") {
		if strings.Index(name, "column") == -1 {
			continue
		}
		return strings.Replace(name, "column:", "", 1)
	}
	return ""
}

// 如果超过一百条, 则分批插入
const batchInsertSize = 100
// batchCreateModelsByPage 分页批量插入
func batchCreateModelsByPage(tx *gorm.DB, data interface{}, tableName string) (db *gorm.DB, err error) {
	var isArr bool
	dataValue := reflect.ValueOf(data)
	switch dataValue.Kind() {
	case reflect.Array, reflect.Slice:
		// 数组
		isArr = true
	case reflect.Struct:
		// 不是数组,单个元素,结构体
		isArr = false
	default:
		// 既不是结构体也不是数组,报错
		return tx, errors.New("data type err, must be array or struct")
	}

	var dataList []interface{}
	if isArr {
		for i := 0; i < dataValue.Len(); i++ {
			tmp := dataValue.Index(i)
			dataList = append(dataList, tmp.Interface())
		}
	} else {
		dataList = append(dataList, data)
	}

	if len(dataList) == 0 {
		return tx, nil
	}
	page := len(dataList) / batchInsertSize
	if len(dataList) % batchInsertSize != 0 {
		page += 1
	}
	for a := 1; a <= page; a++ {
		var bills []interface{}
		if a == page {
			bills = dataList[(a-1) * batchInsertSize:]
		} else {
			bills = dataList[(a-1) * batchInsertSize : a * batchInsertSize]
		}
		sql, err := getBranchInsertSqlByArray(bills, tableName)
		if err != nil {
			return tx, err
		}
		db = tx.Exec(sql)
		if err = db.Error; err != nil {
			return db, errors.New(fmt.printf("batch create data error: %v, sql: %s, tableName: %s",
				err, sql, tableName))
		}
	}
	return db, nil
}
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值