mysql8 批量绑定提交性能优化(修改mysql源码)

        MySQL 5 和 MySQL 8 在批量绑定提交时,是一条一条地提交到服务端:每调用一次 mysql_stmt_execute,客户端都要等服务端处理完这一条并返回应答后,才能发送下一条,最后再统一 commit。由于数据是逐条提交的,性能主要消耗在 TCP 的来回往返上,示例代码如下:

{
    /*
      Baseline benchmark: inserts `num` rows into fis_ied through one prepared
      statement, calling mysql_stmt_execute() once per row, commits once at the
      end, and prints the elapsed time in microseconds.
      `mysql`, `stmt`, `query` and `rc` are declared outside this snippet.
    */
    MYSQL_BIND ps_params[3];
    struct timeval tv1;
    struct timeval tv2;
    fis_ied fis_ied;
    int num = 10000;

    /* Start of the timed region. */
    gettimeofday(&tv1,NULL);
    /* Turn autocommit off so all rows are committed together below. */
    mysql_autocommit(mysql, 0);
    my_stpcpy(query, "INSERT INTO fis_ied(id, st_id, name) values(?, ?, ?)");
    stmt = mysql_simple_prepare(mysql, query);
    check_stmt(stmt);

    /* Parameters 0 and 1 read straight from the fields of the local row. */
    memset(ps_params, 0, sizeof (ps_params));
    ps_params[0].buffer_type = MYSQL_TYPE_LONGLONG;
    ps_params[0].buffer = (char *)&fis_ied.id;

    ps_params[1].buffer_type = MYSQL_TYPE_LONGLONG;
    ps_params[1].buffer = (char *)&fis_ied.st_id;

    /* Parameter 2 is a string; its buffer and length are set per row below. */
    ps_params[2].buffer_type = MYSQL_TYPE_STRING;

    
    for (int i = 0; i < num; ++i){
      
      fis_ied.id = i;
      fis_ied.st_id = 1;
      /* `query` is reused as scratch space for the name; the SQL text has
         already been passed to mysql_simple_prepare() above. */
      sprintf(query, "测试%d", i);


      ps_params[2].buffer = query;

      unsigned long length = strlen(query);
      ps_params[2].length = &length;
      /* Re-bind and execute for every single row. */
      rc = mysql_stmt_bind_param(stmt, ps_params);
      check_execute(stmt, rc);
      rc = mysql_stmt_execute(stmt);
      check_execute(stmt, rc);
    }

    rc = mysql_commit(mysql);
    myquery(rc);
    mysql_autocommit(mysql, 1);
    /* rc still holds the mysql_commit() result here; the return value of
       mysql_autocommit() is not captured. */
    check_execute(stmt, rc);
    
    /* The statement is closed inside the timed region. */
    mysql_stmt_close(stmt);
    gettimeofday(&tv2,NULL);
    long startTimeUsec = tv1.tv_sec*1000000 + tv1.tv_usec;
    long endTimeUsec = tv2.tv_sec*1000000 + tv2.tv_usec;
    printf("时间消耗:  %ld\n",(endTimeUsec - startTimeUsec));
  }

条目数为1万条的时候,时间消耗约3秒。

解决方案:将一万条数据合并成一个报文一次性提交,以下代码基于8.0.12版本:

应用层代码修改变化:

{
    /*
      Batched benchmark: inserts `num` rows into fis_ied with a single
      mysql_stmt_execute() call by binding arrays of values and setting the
      STMT_ATTR_ARRAY_SIZE / STMT_ATTR_ROW_SIZE statement attributes, which are
      added by the patched client described in this article.
      `mysql`, `stmt`, `query`, `rc`, `array_size` and `row_size` are declared
      outside this snippet.
    */
    MYSQL_BIND ps_params[3];
    struct timeval tv1;
    struct timeval tv2;

    /* Start of the timed region. */
    gettimeofday(&tv1,NULL);
    mysql_autocommit(mysql, 0);
    my_stpcpy(query, "INSERT INTO fis_ied(id, st_id, name) values(?, ?, ?)");
    stmt = mysql_simple_prepare(mysql, query);
    check_stmt(stmt);

    /* Init PS-parameters. */
    memset(ps_params, 0, sizeof (ps_params));

    int num = 10000;
    vector<fis_ied> v;
    v.resize(num);

    /* Per-row string pointers and lengths for the `name` column. */
    vector<char*> nameString;
    vector<unsigned long> nameStringLength;
    nameString.resize(num);
    nameStringLength.resize(num);

    for (int i = 0; i < num; ++i){
      v[i].id = i;
      v[i].st_id = 1;
      /* `query` is reused as scratch space; the SQL text has already been
         passed to mysql_simple_prepare() above. */
      sprintf(query, "测试%d", i);
      v[i].name = query;
      /* Points into v[i].name, so `v` must stay alive and unmodified until
         the statement has been executed. */
      nameString[i] = (char*)v[i].name.c_str();
      nameStringLength[i] = strlen(query);
    }

    /* Bind the first element of each column; the remaining rows are supplied
       through the array attributes set below. */
    ps_params[0].buffer_type = MYSQL_TYPE_LONGLONG;
    ps_params[0].buffer = (char *)&v[0].id;

    ps_params[1].buffer_type = MYSQL_TYPE_LONGLONG;
    ps_params[1].buffer = (char *)&v[0].st_id;

    ps_params[2].buffer_type = MYSQL_TYPE_STRING;
    ps_params[2].buffer = (char **)&nameString[0];
    ps_params[2].length = &nameStringLength[0];

    /* Bind parameters. Every call below is checked, as in the per-row
       version, so a failed bind or attribute set is not silently ignored. */
    rc = mysql_stmt_bind_param(stmt, ps_params);
    check_execute(stmt, rc);

    array_size = num;
    rc = mysql_stmt_attr_set(stmt, STMT_ATTR_ARRAY_SIZE, &array_size);
    check_execute(stmt, rc);

    row_size = sizeof(struct fis_ied);
    rc = mysql_stmt_attr_set(stmt, STMT_ATTR_ROW_SIZE, &row_size);
    check_execute(stmt, rc);

    /* Execute! One call sends all rows. */
    rc = mysql_stmt_execute(stmt);
    check_execute(stmt, rc);
    rc = mysql_commit(mysql);
    myquery(rc);
    mysql_autocommit(mysql, 1);
    /* rc still holds the mysql_commit() result here. */
    check_execute(stmt, rc);

    /* Close the statement inside the timed region, exactly like the per-row
       benchmark, so that both measurements cover the same work. */
    mysql_stmt_close(stmt);
    gettimeofday(&tv2,NULL);

    long startTimeUsec = tv1.tv_sec*1000000 + tv1.tv_usec;
    long endTimeUsec = tv2.tv_sec*1000000 + tv2.tv_usec;
    printf("时间消耗:  %ld\n",(endTimeUsec - startTimeUsec));
    
  }

以上代码中,增加了STMT_ATTR_ARRAY_SIZE、STMT_ATTR_ROW_SIZE

STMT_ATTR_ARRAY_SIZE用来告诉服务端这条消息携带的是一个数组(即一次提交的行数)

STMT_ATTR_ROW_SIZE在客户端序列化时使用

客户端主要代码修改:

在序列化时,增加上送array_size,同时,将所有数据一次性上送完

cli_stmt_execute定义修改:

/**
  Serialize the bound parameters of a prepared statement and execute it.

  When the statement has parameters, the payload built here is, in order:
  the 3-byte row count (stmt->array_size), the null-marker bytes, the
  "types follow" flag byte, the parameter types (only when
  stmt->send_types_to_server is set), and then the parameter values of every
  row, row by row. The payload is copied and handed to execute().

  @param stmt  prepared statement handle; parameters must already be bound

  @return 0 on success, non-zero on failure (the error is stored in stmt)
*/
int cli_stmt_execute(MYSQL_STMT *stmt) {
  DBUG_ENTER("cli_stmt_execute");

  if (stmt->param_count) {
    MYSQL *mysql = stmt->mysql;
    NET *net = &mysql->net;
    MYSQL_BIND *param, *param_end;
    char *param_data;
    ulong length;
    uint null_count;
    bool result;
    uint array_iter;

    if (!stmt->bind_param_done) {
      set_stmt_error(stmt, CR_PARAMS_NOT_BOUND, unknown_sqlstate, NULL);
      DBUG_RETURN(1);
    }
    if (mysql->status != MYSQL_STATUS_READY ||
        mysql->server_status & SERVER_MORE_RESULTS_EXISTS) {
      set_stmt_error(stmt, CR_COMMANDS_OUT_OF_SYNC, unknown_sqlstate, NULL);
      DBUG_RETURN(1);
    }

    if (net->vio)
      net_clear(net, 1); /* Sets net->write_pos */
    else {
      set_stmt_errmsg(stmt, net);
      DBUG_RETURN(1);
    }

    /*
      Reserve room for the 3-byte row count, the null-marker bytes and the
      types flag byte BEFORE writing any of them, so that no byte is stored
      ahead of the space reservation.
    */
    null_count = (stmt->param_count + 7) / 8;
    if (my_realloc_str(net, 3 + null_count + 1)) {
      set_stmt_errmsg(stmt, net);
      DBUG_RETURN(1);
    }

    /*
      Number of rows sent in this batch.
      NOTE(review): int3store() keeps only the low 24 bits, and an array_size
      of 0 sends no parameter values at all - confirm that array_size is
      initialised to 1 for ordinary statements and validated in
      mysql_stmt_attr_set().
      NOTE(review): if store_param() still records NULL markers relative to
      net->buff, they would now land on these 3 bytes - confirm it accounts
      for the new prefix.
    */
    int3store(net->write_pos, stmt->array_size);
    net->write_pos += 3;

    memset(net->write_pos, 0, null_count);
    net->write_pos += null_count;
    param_end = stmt->params + stmt->param_count;

    /* In case if buffers (type) altered, indicate to server */
    *(net->write_pos)++ = (uchar)stmt->send_types_to_server;
    if (stmt->send_types_to_server) {
      if (my_realloc_str(net, 2 * stmt->param_count)) {
        set_stmt_errmsg(stmt, net);
        DBUG_RETURN(1);
      }
      /*
        Store types of parameters in first in first package
        that is sent to the server.
      */
      for (param = stmt->params; param < param_end; param++)
        store_param_type(&net->write_pos, param);
    }

    for (array_iter = 0; array_iter < stmt->array_size; ++array_iter) {
      for (param = stmt->params; param < param_end; param++) {
        /*
          A parameter supplied through mysql_stmt_send_long_data() has no
          inline value in ANY row. The flag must therefore stay set for the
          whole batch; clearing it during row 0 would make rows 1..N-1
          serialize a value that the server-side parser does not read.
        */
        if (param->long_data_used) continue;
        if (store_param(stmt, array_iter, param)) DBUG_RETURN(1);
      }
    }

    /* All rows are serialized: clear the long-data flags for the next
       execute call. */
    for (param = stmt->params; param < param_end; param++)
      param->long_data_used = 0;

    length = (ulong)(net->write_pos - net->buff);
    /* TODO: Look into avoiding the following memdup */
    if (!(param_data = pointer_cast<char *>(
              my_memdup(PSI_NOT_INSTRUMENTED, net->buff, length, MYF(0))))) {
      set_stmt_error(stmt, CR_OUT_OF_MEMORY, unknown_sqlstate, NULL);
      DBUG_RETURN(1);
    }
    result = execute(stmt, param_data, length);
    stmt->send_types_to_server = 0;
    my_free(param_data);
    DBUG_RETURN(result);
  }
  DBUG_RETURN((int)execute(stmt, 0, 0));
}

服务端主要代码修改:

解析报文时,增加解析array_size,同时,把所有数据全部解析出来。

Protocol_classic::parse_packet反序列化报文代码修改:

/**
  Decode the raw input packet of a client command into @c data.

  For COM_STMT_EXECUTE the payload is read as: statement id (4 bytes),
  a cursor flag byte followed by 4 bytes that are skipped, then - only when
  the statement exists and has parameters - the row count of the batch
  (3 bytes), the null bits, the "new types" flag byte, the optional parameter
  types, and finally the parameter values of every row. One PS_PARAM entry
  is produced per (row, parameter) pair, row by row.

  @param[out] data  command data filled in from the packet
  @param      cmd   command the packet belongs to

  @return false on success, true if the packet is malformed (in which case
          ER_MALFORMED_PACKET is raised and bad_packet is set)
*/
bool Protocol_classic::parse_packet(union COM_DATA *data,
                                    enum_server_command cmd) {
  DBUG_ENTER("Protocol_classic::parse_packet");
  switch (cmd) {
    case COM_INIT_DB: {
      data->com_init_db.db_name =
          reinterpret_cast<const char *>(input_raw_packet);
      data->com_init_db.length = input_packet_length;
      break;
    }
    case COM_REFRESH: {
      if (input_packet_length < 1) goto malformed;
      data->com_refresh.options = input_raw_packet[0];
      break;
    }
    case COM_PROCESS_KILL: {
      if (input_packet_length < 4) goto malformed;
      data->com_kill.id = (ulong)uint4korr(input_raw_packet);
      break;
    }
    case COM_SET_OPTION: {
      if (input_packet_length < 2) goto malformed;
      data->com_set_option.opt_command = uint2korr(input_raw_packet);
      break;
    }
    case COM_STMT_EXECUTE: {
      if (input_packet_length < 9) goto malformed;
      uchar *read_pos = input_raw_packet;
      size_t packet_left = input_packet_length;

      // Get the statement id
      data->com_stmt_execute.stmt_id = uint4korr(read_pos);
      read_pos += 4;
      packet_left -= 4;
      // Get execution flags
      data->com_stmt_execute.open_cursor = static_cast<bool>(*read_pos);
      read_pos += 5;
      packet_left -= 5;
      DBUG_PRINT("info", ("stmt %lu", data->com_stmt_execute.stmt_id));
      DBUG_PRINT("info", ("Flags %lu", data->com_stmt_execute.open_cursor));

      // Get the statement by id
      Prepared_statement *stmt =
          m_thd->stmt_map.find(data->com_stmt_execute.stmt_id);
      data->com_stmt_execute.parameter_count = 0;

      /*
        If no statement found there's no need to generate error.
        It will be generated in sql_parse.cc which will check again for the id.
      */
      if (!stmt || stmt->param_count < 1) break;

      /*
        Get the array size. Only the first 9 bytes were validated above and
        all of them are consumed already, so make sure the 3-byte row count
        is really present before reading it.
      */
      if (packet_left < 3) goto malformed;
      const uint array_size = uint3korr(read_pos);
      read_pos += 3;
      packet_left -= 3;

      /*
        The row count comes from the client. Reject values for which
        param_count * array_size would wrap around uint: the PS_PARAM array
        would then be smaller than the number of rows executed later.
        stmt->param_count is >= 1 here, so the division is safe.
      */
      const uint max_total_params = ~0U;
      if (array_size > max_total_params / stmt->param_count) goto malformed;
      stmt->array_size = array_size;

      uint param_count = stmt->param_count * array_size;

      DBUG_PRINT("info", ("array_size:%d, count:%d", stmt->array_size, param_count));
      data->com_stmt_execute.parameters =
          static_cast<PS_PARAM *>(m_thd->alloc(param_count * sizeof(PS_PARAM)));
      if (!data->com_stmt_execute.parameters)
        goto malformed; /* purecov: inspected */

      /* Then comes the null bits (one set, shared by all rows) */
      const uint null_bits_packet_len = (stmt->param_count + 7) / 8;
      if (packet_left < null_bits_packet_len) goto malformed;
      unsigned char *null_bits = read_pos;
      read_pos += null_bits_packet_len;
      packet_left -= null_bits_packet_len;

      PS_PARAM *params = data->com_stmt_execute.parameters;

      /* Then comes the types byte. If set, new types are provided */
      if (!packet_left) goto malformed;
      bool has_new_types = static_cast<bool>(*read_pos++);
      --packet_left;
      data->com_stmt_execute.has_new_types = has_new_types;
      if (has_new_types) {
        DBUG_PRINT("info", ("Types provided"));
        for (uint i = 0; i < param_count; ++i) {
          if (i < stmt->param_count) {
            // Types are sent once; they describe the columns of the first row.
            if (packet_left < 2) goto malformed;

            ushort type_code = sint2korr(read_pos);
            read_pos += 2;
            packet_left -= 2;

            const uint signed_bit = 1 << 15;
            params[i].type =
                static_cast<enum enum_field_types>(type_code & ~signed_bit);
            params[i].unsigned_type = static_cast<bool>(type_code & signed_bit);
          } else {
            // Every further row reuses the type of the same column of row 0.
            params[i].type = params[i % stmt->param_count].type;
            params[i].unsigned_type = params[i % stmt->param_count].unsigned_type;
          }

          DBUG_PRINT("info", ("type=%u", (uint)params[i].type));
          DBUG_PRINT("info", ("flags=%u", (uint)params[i].unsigned_type));
        }
      }
      /*
        No check for packet_left here or in case of only long data
        we will return malformed, although the packet will be correct
      */

      /* Here comes the real data */
      for (uint i = 0; i < param_count; ++i) {
        // Index of this value's parameter within its row.
        uint column = i % stmt->param_count;
        params[i].null_bit =
            static_cast<bool>(null_bits[column / 8] & (1 << (column & 7)));
        // Check if parameter is null
        if (params[i].null_bit) {
          DBUG_PRINT("info", ("null param"));
          params[i].value = nullptr;
          params[i].length = 0;
          data->com_stmt_execute.parameter_count++;
          continue;
        }
        enum enum_field_types type =
            has_new_types ? params[i].type : stmt->param_array[column]->data_type();
        if (stmt->param_array[column]->state == Item_param::LONG_DATA_VALUE) {
          // Long data was sent separately: nothing is read for this column.
          DBUG_PRINT("info", ("long data"));
          if (!((type >= MYSQL_TYPE_TINY_BLOB) && (type <= MYSQL_TYPE_STRING)))
            goto malformed;
          data->com_stmt_execute.parameter_count++;

          continue;
        }

        bool buffer_underrun = false;
        ulong header_len;

        // Set parameter length.
        params[i].length = get_ps_param_len(type, read_pos, packet_left,
                                            &header_len, &buffer_underrun);
        if (buffer_underrun) goto malformed;

        read_pos += header_len;
        packet_left -= header_len;

        // The value itself must fit into what is left of the packet.
        if (params[i].length > packet_left) goto malformed;

        // Set parameter value
        params[i].value = read_pos;
        read_pos += params[i].length;
        packet_left -= params[i].length;
        data->com_stmt_execute.parameter_count++;
        DBUG_PRINT("info", ("param len %ul", (uint)params[i].length));
      }
      DBUG_PRINT("info", ("param count %ul",
                          (uint)data->com_stmt_execute.parameter_count));
      break;
    }
    case COM_STMT_FETCH: {
      if (input_packet_length < 8) goto malformed;
      data->com_stmt_fetch.stmt_id = uint4korr(input_raw_packet);
      data->com_stmt_fetch.num_rows = uint4korr(input_raw_packet + 4);
      break;
    }
    case COM_STMT_SEND_LONG_DATA: {
      if (input_packet_length < MYSQL_LONG_DATA_HEADER) goto malformed;
      data->com_stmt_send_long_data.stmt_id = uint4korr(input_raw_packet);
      data->com_stmt_send_long_data.param_number =
          uint2korr(input_raw_packet + 4);
      data->com_stmt_send_long_data.longdata = input_raw_packet + 6;
      data->com_stmt_send_long_data.length = input_packet_length - 6;
      break;
    }
    case COM_STMT_PREPARE: {
      data->com_stmt_prepare.query =
          reinterpret_cast<const char *>(input_raw_packet);
      data->com_stmt_prepare.length = input_packet_length;
      break;
    }
    case COM_STMT_CLOSE: {
      if (input_packet_length < 4) goto malformed;

      data->com_stmt_close.stmt_id = uint4korr(input_raw_packet);
      break;
    }
    case COM_STMT_RESET: {
      if (input_packet_length < 4) goto malformed;

      data->com_stmt_reset.stmt_id = uint4korr(input_raw_packet);
      break;
    }
    case COM_QUERY: {
      data->com_query.query = reinterpret_cast<const char *>(input_raw_packet);
      data->com_query.length = input_packet_length;
      break;
    }
    case COM_FIELD_LIST: {
      /*
        We have name + wildcard in packet, separated by endzero
      */
      ulong len = strend((char *)input_raw_packet) - (char *)input_raw_packet;

      if (len >= input_packet_length || len > NAME_LEN) goto malformed;

      data->com_field_list.table_name = input_raw_packet;
      data->com_field_list.table_name_length = len;

      data->com_field_list.query = input_raw_packet + len + 1;
      data->com_field_list.query_length = input_packet_length - len;
      break;
    }
    default:
      break;
  }

  DBUG_RETURN(false);

malformed:
  my_error(ER_MALFORMED_PACKET, MYF(0));
  bad_packet = true;
  DBUG_RETURN(true);
}

sql_prepare.cc mysqld_stmt_execute代码修改:

/**
  Execute a prepared statement once for every row of a batched
  COM_STMT_EXECUTE request.

  After the parameters have been set, each of the stmt->array_size rows is
  bound with insert_params() and run with execute_loop(). Processing stops at
  the first row that fails to bind or to execute, so later rows are not
  applied on top of an already recorded error.

  @param thd            current thread
  @param stmt           prepared statement to execute
  @param has_new_types  true if the client sent parameter types
  @param execute_flags  execution flags; CURSOR_TYPE_READ_ONLY requests a cursor
  @param parameters     parsed parameter values, row by row
*/
void mysqld_stmt_execute(THD *thd, Prepared_statement *stmt, bool has_new_types,
                         ulong execute_flags, PS_PARAM *parameters) {
  DBUG_ENTER("mysqld_stmt_execute");

#if defined(ENABLED_PROFILING)
  thd->profiling->set_query_source(stmt->m_query_string.str,
                                   stmt->m_query_string.length);
#endif
  DBUG_PRINT("info", ("stmt: %p", stmt));

  bool switch_protocol = thd->is_classic_protocol();
  if (switch_protocol) {
    // set the current client capabilities before switching the protocol
    thd->protocol_binary.set_client_capabilities(
        thd->get_protocol()->get_client_capabilities());
    thd->push_protocol(&thd->protocol_binary);
  }

  MYSQL_EXECUTE_PS(thd->m_statement_psi, stmt->m_prepared_stmt);

  // Query text for binary, general or slow log, if any of them is open
  String expanded_query;

  // If no error happened while setting the parameters, execute statement.
  if (!stmt->set_parameters(&expanded_query, has_new_types, parameters)) {
    // The cursor flag is the same for every row of the batch.
    const bool open_cursor =
        static_cast<bool>(execute_flags & (ulong)CURSOR_TYPE_READ_ONLY);

    // NOTE(review): with array_size == 0 nothing is executed at all - confirm
    // that ordinary (non-batched) statements always arrive with array_size 1.
    for (uint arrayIter = 0; arrayIter < stmt->array_size; ++arrayIter) {
      // A non-zero result means this row's values could not be bound.
      if (stmt->insert_params(&expanded_query, arrayIter, parameters)) break;
      // Stop at the first row whose execution fails.
      if (stmt->execute_loop(&expanded_query, open_cursor)) break;
    }
  }

  if (switch_protocol) thd->pop_protocol();

  sp_cache_enforce_limit(thd->sp_proc_cache, stored_program_cache_size);
  sp_cache_enforce_limit(thd->sp_func_cache, stored_program_cache_size);

  /* Close connection socket; for use with client testing (Bug#43560). */
  DBUG_EXECUTE_IF("close_conn_after_stmt_execute",
                  thd->get_protocol()->shutdown(););

  DBUG_VOID_RETURN;
}

以上代码修改完后,同样的1万条数据,耗时从3秒降为208毫秒

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值