HDFS Centralized Cache Management

This article describes HDFS centralized cache management in detail, covering its applicable scenarios, structural design, and management mechanism. It is suited to caching public resource files and short-term hot data files; caching is managed through CacheDirective and CachePool, with the CacheManager and CacheReplicationMonitor services working together to provide efficient data caching and monitoring.

Preface


As is well known, HDFS is a distributed file system storing massive amounts of data, so the number of I/O read and write operations per day is naturally very high. In an earlier article we therefore discussed using HDFS heterogeneous storage to separate hot and cold data, with the nice property that both still live in the same cluster. This raises a question: can we optimize further? Some data files are shared by everyone within a certain time window and are accessed even more frequently than ordinary hot files, yet once that window has passed they become ordinary files again. This article presents HDFS's answer to this need: HDFS centralized cache management, a feature that many people may have overlooked.

Applicable Scenarios for HDFS Caching


First of all, we need to understand the scenarios that HDFS caching is suited for; in other words, what concrete problems it can solve for us.

Caching hot public resource files and short-term, temporary hot data files stored in HDFS

The first case: public resource files. These can be dependency jar packages stored in HDFS, .so files that learning algorithms depend on, and so on. The advantage of putting such files on HDFS is that they can be shared globally across the cluster instead of being kept as dependencies on every local machine, and they are easy to manage because they can be updated directly in HDFS. For this particular scenario an even better approach is to turn them into a distributed cache, though; otherwise programs will send a huge number of requests to the NameNode to fetch the contents of these resource files, and that request volume is frightening: it is not one request and done, but one request per invocation.

The second case: short-term, temporary hot data files. For example, the report statistics that a cluster needs to produce every day require reading the previous day's or the last week's data for offline analysis, but data outside that window is rarely used again and can be regarded as cold. In this situation, the data falling inside the time window can be cached, and once it expires it is simply evicted from the cache.

Both of the above scenarios are ones where HDFS Cache fits very well.
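To make the second case concrete, the sketch below shows how a client could cache the previous day's report data for 24 hours through the HDFS client API, so that the data drops out of the cache automatically once the window passes. This is a minimal, illustrative sketch only: it assumes fs.defaultFS points at an HDFS cluster with caching enabled, that a cache pool named report-pool already exists (pool creation is shown later), and the report path is made up for the example. The CacheDirective and CachePool concepts behind these calls are explained in the next section.

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.hdfs.DistributedFileSystem;
import org.apache.hadoop.hdfs.protocol.CacheDirectiveInfo;

public class CacheDailyReport {
  public static void main(String[] args) throws Exception {
    Configuration conf = new Configuration();
    // Assumes fs.defaultFS points at an HDFS cluster with caching enabled
    DistributedFileSystem dfs = (DistributedFileSystem) FileSystem.get(conf);
    // Cache yesterday's report directory for 24 hours; once the directive
    // expires, its blocks are evicted from the DataNode cache automatically.
    long id = dfs.addCacheDirective(new CacheDirectiveInfo.Builder()
        .setPath(new Path("/data/report/2016-01-01"))  // hypothetical path
        .setPool("report-pool")                        // pool assumed to exist
        .setReplication((short) 1)                     // number of cached replicas
        .setExpiration(CacheDirectiveInfo.Expiration.newRelative(
            24L * 60 * 60 * 1000))                     // relative TTL in ms
        .build());
    System.out.println("Added cache directive " + id);
    dfs.close();
  }
}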

Structural Design of the HDFS Cache


In HDFS, what ultimately gets cached is still essentially an INodeFile. Logically, however, the following concepts are introduced.

CacheDirective


A CacheDirective is the basic unit of caching. Note that a CacheDirective does not have to target a directory; it can also target a single file. It contains the following main fields:

public final class CacheDirective implements IntrusiveCollection.Element {

  // Unique identifier of the directive
  private final long id;
  // Target path to be cached
  private final String path;
  // Cache replication factor for the path
  private final short replication;
  // The CachePool this directive belongs to
  private CachePool pool;
  // Expiration time
  private final long expiryTime;

  // Related statistics
  private long bytesNeeded;
  private long bytesCached;
  private long filesNeeded;
  private long filesCached;
  ...

Here we encounter a new concept, CachePool, from which we can draw the following conclusion:

A CacheDirective belongs to a corresponding CachePool.
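The statistics fields above (bytesNeeded, bytesCached, filesNeeded, filesCached) are also what clients see when they list directives. As a small sketch under the same assumptions as the earlier example (a reachable cluster and an existing pool), the following helper prints how far caching has progressed for each directive in a pool:

import org.apache.hadoop.fs.RemoteIterator;
import org.apache.hadoop.hdfs.DistributedFileSystem;
import org.apache.hadoop.hdfs.protocol.CacheDirectiveEntry;
import org.apache.hadoop.hdfs.protocol.CacheDirectiveInfo;
import org.apache.hadoop.hdfs.protocol.CacheDirectiveStats;

public class CacheProgressPrinter {
  // List the directives in a pool and print how much of each target
  // path has actually been cached so far.
  public static void printCacheProgress(DistributedFileSystem dfs,
      String poolName) throws Exception {
    CacheDirectiveInfo filter =
        new CacheDirectiveInfo.Builder().setPool(poolName).build();
    RemoteIterator<CacheDirectiveEntry> it = dfs.listCacheDirectives(filter);
    while (it.hasNext()) {
      CacheDirectiveEntry entry = it.next();
      CacheDirectiveStats stats = entry.getStats();
      System.out.printf("%s: %d/%d bytes cached, %d/%d files cached%n",
          entry.getInfo().getPath(),
          stats.getBytesCached(), stats.getBytesNeeded(),
          stats.getFilesCached(), stats.getFilesNeeded());
    }
  }
}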

CachePool


The definition of the CachePool concept is given below:

public final class CachePool {

  // Name of the cache pool
  @Nonnull
  private final String poolName;
  // Owner user name
  @Nonnull
  private String ownerName;
  // Owner group name
  @Nonnull
  private String groupName;
  // Cache pool permissions
  /**
   * Cache pool permissions.
   * 
   * READ permission means that you can list the cache directives in this pool.
   * WRITE permission means that you can add, remove, or modify cache directives
   *       in this pool.
   * EXECUTE permission is unused.
   */
  @Nonnull
  private FsPermission mode;
  // Maximum number of bytes allowed to be cached in this pool
  /**
   * Maximum number of bytes that can be cached in this pool.
   */
  private long limit;
  // Maximum relative expiration time
  /**
   * Maximum duration that a CacheDirective in this pool remains valid,
   * in milliseconds.
   */
  private long maxRelativeExpiryMs;
  // Statistics-related fields
  private long bytesNeeded;
  private long bytesCached;
  private long filesNeeded;
  private long filesCached;
  ...
  // List of cache directives in this pool
  @Nonnull
  private final DirectiveList directiveList = new DirectiveList(this);
  ...

As we can see, a CachePool does indeed maintain a list of CacheDirective cache units. These CachePools are in turn governed by the CacheManager, which acts as the overall administrator here. The CacheManager also runs a very important service, the CacheReplicationMonitor: this monitor periodically scans the latest set of cached paths and dispatches the caching work to the corresponding DataNodes. This thread service will be covered in more detail later. The overall structure of the HDFS Cache is therefore as shown in the figure below:

[Figure: overall structural relationship among CacheManager, CachePools, CacheDirectives, and DataNodes]
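To make the pool side concrete as well, the sketch below creates a cache pool whose settings map one-to-one onto the CachePool fields listed above (owner, group, mode, limit, maxRelativeExpiryMs). The pool name, owner, group, and limits are illustrative values, and the same connection assumptions as in the earlier sketch apply:

import org.apache.hadoop.fs.permission.FsPermission;
import org.apache.hadoop.hdfs.DistributedFileSystem;
import org.apache.hadoop.hdfs.protocol.CachePoolInfo;

public class CreateReportPool {
  // Create a cache pool whose settings mirror the CachePool fields above:
  // owner/group, permissions, byte limit, and maximum relative expiry.
  public static void createPool(DistributedFileSystem dfs) throws Exception {
    dfs.addCachePool(new CachePoolInfo("report-pool")       // poolName
        .setOwnerName("etl")                                // ownerName (illustrative)
        .setGroupName("hadoop")                             // groupName (illustrative)
        .setMode(new FsPermission((short) 0755))            // mode
        .setLimit(10L * 1024 * 1024 * 1024)                 // limit: 10 GB
        .setMaxRelativeExpiryMs(7L * 24 * 60 * 60 * 1000)); // max TTL: 7 days
  }
}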

Analysis of the HDFS Cache Management Mechanism


I have actually written an article about the HDFS cache management mechanism before (HDFS缓存机制, "HDFS Caching Mechanism"), but in retrospect it was not very comprehensive: it completely left out the CacheAdmin part. So this section adds analysis of the following two aspects:

  • How CacheAdmin CLI commands are implemented in the CacheManager
  • How the CacheManager's CacheReplicationMonitor caches the target files onto the DataNodes

Let us first look at what the first point involves.

Implementation of CacheAdmin CLI Commands in the CacheManager


Every operation command in CacheAdmin ultimately maps, through an RPC call, to a concrete operation method of the CacheManager. In this process, the following main questions need to be answered:

  • What kind of CachePool and CacheDirective relationships does the CacheManager maintain?
  • What special details are involved when a new CacheDirective or CachePool is added?

For the first question, the CacheManager does indeed maintain several kinds of mappings over CachePools and CacheDirectives, as follows:

public final class CacheManager {

  ...
  // Mapping from CacheDirective id to CacheDirective
  /**
   * Cache directives, sorted by ID.
   *
   * listCacheDirectives relies on the ordering of elements in this map
   * to track what has already been listed by the client.
   */
  private final TreeMap<Long, CacheDirective> directivesById =
      new TreeMap<Long, CacheDirective>();
  ...
  // Mapping from cached path to a list of CacheDirectives, which implies that
  // a single file/directory path can be cached multiple times simultaneously
  /**
   * Cache directives, sorted by path
   */
  private final TreeMap<String, List<CacheDirective>> directivesByPath =
      new TreeMap<String, List<CacheDirective>>();
  // Mapping from cache pool name to CachePool
  /**
   * Cache pools, sorted by name.
   */
  private final TreeMap<String, CachePool> cachePools =
      new TreeMap<String, CachePool>();
  ...

These are the three main mappings stored in the CacheManager object. The second one, the mapping from a cached path to a list of cache directives, puzzled me at first, but it turns out that the same path can indeed be cached multiple times. Because these three structures are defined, adding a CacheDirective instance involves several update operations. Take the addDirective method as an example:

  public CacheDirectiveInfo addDirective(
      CacheDirectiveInfo info, FSPermissionChecker pc, EnumSet<CacheFlag> flags)
      throws IOException {
    assert namesystem.hasWriteLock();
    CacheDirective directive;
    try {
      // Get the cache pool this directive belongs to
      CachePool pool = getCachePool(validatePoolName(info));
      ...
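The rest of addDirective (omitted above) validates the requested path, replication, and expiry time against the pool's settings, creates the CacheDirective with the next available id, and then registers it in the three mappings described earlier. The following is a simplified sketch of that bookkeeping only, not the actual Hadoop implementation; the helper name registerDirective is hypothetical, and it assumes the directivesById/directivesByPath fields and the pool's directive list shown above:

  private void registerDirective(CacheDirective directive, CachePool pool) {
    // 1. id -> directive
    directivesById.put(directive.getId(), directive);
    // 2. path -> list of directives; the same path may be cached several
    //    times, so the value is a list that is created lazily
    List<CacheDirective> directives = directivesByPath.get(directive.getPath());
    if (directives == null) {
      directives = new ArrayList<CacheDirective>(1);
      directivesByPath.put(directive.getPath(), directives);
    }
    directives.add(directive);
    // 3. attach the directive to the directive list of its owning pool
    pool.getDirectiveList().add(directive);
  }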