TensorFlow2.8.0代码分析之例子\examples\speech_commands\test_streaming_accuracy之main函数

函数作用:对一个连续的音频流,使用音频识别模型,生成识别精度的统计数据。

使用时需要提供待识别的长音频文件,以及一个标注文件(ground truth,每行格式为“<词>,<毫秒时间戳>”,列出音频中每个词及其出现的时间);工具会在音频流上逐窗口运行模型,统计识别正确、识别错误和误报的数量,并给出匹配情况的百分比统计。

使用此工具的示例如下:

bazel run tensorflow/examples/speech_commands:test_streaming_accuracy -- \
--wav=/tmp/streaming_test_bg.wav \
--graph=/tmp/conv_frozen.pb \
--labels=/tmp/speech_commands_train/conv_labels.txt \
--ground_truth=/tmp/streaming_test_labels.txt --verbose \
--clip_duration_ms=1000 --detection_threshold=0.70 --average_window_ms=500 \
--suppression_ms=500 --time_tolerance_ms=1500

函数全局调用情况如下:

 

函数声明情况如下:

 

函数流程图如下:

 

函数逻辑顺序图如下:

 

函数原始代码如下:

int main(int argc, char* argv[]) {

  string wav = "";

  string graph = "";

  string labels = "";

  string ground_truth = "";

  string input_data_name = "decoded_sample_data:0";

  string input_rate_name = "decoded_sample_data:1";

  string output_name = "labels_softmax";

  int32_t clip_duration_ms = 1000;

  int32_t clip_stride_ms = 30;

  int32_t average_window_ms = 500;

  int32_t time_tolerance_ms = 750;

  int32_t suppression_ms = 1500;

  float detection_threshold = 0.7f;

  bool verbose = false;

  std::vector<Flag> flag_list = {

      Flag("wav", &wav, "audio file to be identified"),

      Flag("graph", &graph, "model to be executed"),

      Flag("labels", &labels, "path to file containing labels"),

      Flag("ground_truth", &ground_truth,

           "path to file containing correct times and labels of words in the "

           "audio as <word>,<timestamp in ms> lines"),

      Flag("input_data_name", &input_data_name,

           "name of input data node in model"),

      Flag("input_rate_name", &input_rate_name,

           "name of input sample rate node in model"),

      Flag("output_name", &output_name, "name of output node in model"),

      Flag("clip_duration_ms", &clip_duration_ms,

           "length of recognition window"),

      Flag("average_window_ms", &average_window_ms,

           "length of window to smooth results over"),

      Flag("time_tolerance_ms", &time_tolerance_ms,

           "maximum gap allowed between a recognition and ground truth"),

      Flag("suppression_ms", &suppression_ms,

           "how long to ignore others for after a recognition"),

      Flag("clip_stride_ms", &clip_stride_ms, "how often to run recognition"),

      Flag("detection_threshold", &detection_threshold,

           "what score is required to trigger detection of a word"),

      Flag("verbose", &verbose, "whether to log extra debugging information"),

  };

  string usage = tensorflow::Flags::Usage(argv[0], flag_list);

  const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);

  if (!parse_result) {

    LOG(ERROR) << usage;

    return -1;

  }

  // We need to call this to set up global state for TensorFlow.设置TensorFlow

  tensorflow::port::InitMain(argv[0], &argc, &argv);

  if (argc > 1) {

    LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;

    return -1;

  }

  // First we load and initialize the model.初始化模型

  std::unique_ptr<tensorflow::Session> session;

  Status load_graph_status = LoadGraph(graph, &session);

  if (!load_graph_status.ok()) {

    LOG(ERROR) << load_graph_status;

    return -1;

  }

  std::vector<string> labels_list;

  Status read_labels_status = ReadLabelsFile(labels, &labels_list);

  if (!read_labels_status.ok()) {

    LOG(ERROR) << read_labels_status;

    return -1;

  }

  std::vector<std::pair<string, int64_t>> ground_truth_list;

  Status read_ground_truth_status =

      tensorflow::ReadGroundTruthFile(ground_truth, &ground_truth_list);

  if (!read_ground_truth_status.ok()) {

    LOG(ERROR) << read_ground_truth_status;

    return -1;

  }

  string wav_string;

  Status read_wav_status = tensorflow::ReadFileToString(

      tensorflow::Env::Default(), wav, &wav_string);

  if (!read_wav_status.ok()) {

    LOG(ERROR) << read_wav_status;

    return -1;

  }

  std::vector<float> audio_data;

  uint32 sample_count;

  uint16 channel_count;

  uint32 sample_rate;

  Status decode_wav_status = tensorflow::wav::DecodeLin16WaveAsFloatVector(

      wav_string, &audio_data, &sample_count, &channel_count, &sample_rate);

  if (!decode_wav_status.ok()) {

    LOG(ERROR) << decode_wav_status;

    return -1;

  }

  if (channel_count != 1) {

    LOG(ERROR) << "Only mono .wav files can be used, but input has "

               << channel_count << " channels.";

    return -1;

  }

  const int64_t clip_duration_samples = (clip_duration_ms * sample_rate) / 1000;

  const int64_t clip_stride_samples = (clip_stride_ms * sample_rate) / 1000;

  Tensor audio_data_tensor(tensorflow::DT_FLOAT,

                           tensorflow::TensorShape({clip_duration_samples, 1}));

  Tensor sample_rate_tensor(tensorflow::DT_INT32, tensorflow::TensorShape({}));

  sample_rate_tensor.scalar<int32>()() = sample_rate;

  tensorflow::RecognizeCommands recognize_commands(

      labels_list, average_window_ms, detection_threshold, suppression_ms);

  std::vector<std::pair<string, int64_t>> all_found_words;

  tensorflow::StreamingAccuracyStats previous_stats;

  const int64_t audio_data_end = (sample_count - clip_duration_samples);

  for (int64_t audio_data_offset = 0; audio_data_offset < audio_data_end;

       audio_data_offset += clip_stride_samples) {

    const float* input_start = &(audio_data[audio_data_offset]);

    const float* input_end = input_start + clip_duration_samples;

    std::copy(input_start, input_end, audio_data_tensor.flat<float>().data());

    // Actually run the audio through the model.用模型处理音频

    std::vector<Tensor> outputs;

    Status run_status = session->Run({{input_data_name, audio_data_tensor},

                                      {input_rate_name, sample_rate_tensor}},

                                     {output_name}, {}, &outputs);

    if (!run_status.ok()) {

      LOG(ERROR) << "Running model failed: " << run_status;

      return -1;

    }

    const int64_t current_time_ms = (audio_data_offset * 1000) / sample_rate;

    string found_command;

    float score;

    bool is_new_command;

    Status recognize_status = recognize_commands.ProcessLatestResults(

        outputs[0], current_time_ms, &found_command, &score, &is_new_command);

    if (!recognize_status.ok()) {

      LOG(ERROR) << "Recognition processing failed: " << recognize_status;

      return -1;

    }

    if (is_new_command && (found_command != "_silence_")) {

      all_found_words.push_back({found_command, current_time_ms});

      if (verbose) {

        tensorflow::StreamingAccuracyStats stats;

        tensorflow::CalculateAccuracyStats(ground_truth_list, all_found_words,

                                           current_time_ms, time_tolerance_ms,

                                           &stats);

        int32_t false_positive_delta = stats.how_many_false_positives -

                                       previous_stats.how_many_false_positives;

        int32_t correct_delta = stats.how_many_correct_words -

                                previous_stats.how_many_correct_words;

        int32_t wrong_delta =

            stats.how_many_wrong_words - previous_stats.how_many_wrong_words;

        string recognition_state;

        if (false_positive_delta == 1) {

          recognition_state = " (False Positive)";

        } else if (correct_delta == 1) {

          recognition_state = " (Correct)";

        } else if (wrong_delta == 1) {

          recognition_state = " (Wrong)";

        } else {

          LOG(ERROR) << "Unexpected state in statistics";

        }

        LOG(INFO) << current_time_ms << "ms: " << found_command << ": " << score

                  << recognition_state;

        previous_stats = stats;

        tensorflow::PrintAccuracyStats(stats);

      }

    }

  }

  tensorflow::StreamingAccuracyStats stats;

  tensorflow::CalculateAccuracyStats(ground_truth_list, all_found_words, -1,

                                     time_tolerance_ms, &stats);

  tensorflow::PrintAccuracyStats(stats);

  return 0;

}

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

qqq9668

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值