ESPnet Code Walkthrough (1): asr.py

Location: espnet/espnet/asr/pytorch_backend/asr.py

I. Reading the input and output dimensions

idim_list: feature vector dimensions, [23] here (20 Fbank + 3 pitch)
odim: 483 (the number of Chinese characters)

    # Get the input/output dimensions from the json file. idim_list: feature dims [23]; odim: 483 (number of Chinese characters)
    with open(args.valid_json, "rb") as f:
        valid_json = json.load(f)["utts"]
    utts = list(valid_json.keys())
    idim_list = [
        int(valid_json[utts[0]]["input"][i]["shape"][-1]) for i in range(args.num_encs)
    ]  # input dimensions
    odim = int(valid_json[utts[0]]["output"][0]["shape"][-1])  # output dimension
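
For reference, here is a hypothetical miniature of the parsed "utts" dict (field values invented for illustration); each shape is [length, dim], so the last element is the dimension:

    # Hypothetical miniature of valid_json; values are made up for illustration.
    valid_json = {
        "BAC009S0002W0122": {
            "input": [{"feat": "some.ark:123", "name": "input1", "shape": [418, 23]}],
            "output": [{"name": "target1", "shape": [15, 483], "tokenid": "3 127 56"}],
        }
    }
    utts = list(valid_json.keys())
    idim = int(valid_json[utts[0]]["input"][0]["shape"][-1])   # -> 23
    odim = int(valid_json[utts[0]]["output"][0]["shape"][-1])  # -> 483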

II. Loading the configured model

load_trained_modules(idim, odim, args, interface=ASRInterface)
returns a model with initialized weights.
Which model is built is determined by args.model_module.

    # Load the configured model
    model = load_trained_modules(idim_list[0], odim, args)
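
Conceptually, load_trained_modules dynamically imports the class named by args.model_module and instantiates it. A minimal sketch of the idea (the real function also supports initializing sub-modules from pretrained snapshots via flags such as args.enc_init / args.dec_init):

    from espnet.utils.dynamic_import import dynamic_import

    # Rough sketch of what load_trained_modules boils down to
    # when no pretrained snapshots are given.
    model_module = "espnet.nets.pytorch_backend.e2e_asr:E2E"  # args.model_module
    model_class = dynamic_import(model_module)
    model = model_class(idim_list[0], odim, args)  # weights are initialized in __init__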

III. Writing the parameters to model.json

    # Write the input/output dimensions and all model parameters from the .yaml file into model.json
    model_conf = args.outdir + "/model.json"
    with open(model_conf, "wb") as f:
        logging.info("writing a model config file to " + model_conf)
        f.write(
            json.dumps(
                (idim_list[0] if args.num_encs == 1 else idim_list, odim, vars(args)),
                indent=4,
                ensure_ascii=False,
                sort_keys=True,
            ).encode("utf_8")
        )
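
The tuple layout matters later: at decoding time this file is read back, roughly like the following sketch (espnet.asr.asr_utils.get_model_conf does approximately this):

    import argparse
    import json

    # Sketch of reading model.json back at decode time.
    with open(model_conf, "rb") as f:
        idim, odim, conf = json.load(f)
    train_args = argparse.Namespace(**conf)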

IV. Setting up the optimizer (Adam as an example)

    model_params = model.parameters()
    optimizer = torch.optim.Adam(model_params, weight_decay=args.weight_decay)
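
asr.py actually selects the optimizer from args.opt; besides adam there is at least an adadelta branch. A hedged sketch of the selection (argument names from memory, check the source):

    import torch

    # Simplified sketch of the optimizer selection in asr.py.
    if args.opt == "adadelta":
        optimizer = torch.optim.Adadelta(
            model.parameters(), rho=0.95, eps=args.eps, weight_decay=args.weight_decay
        )
    elif args.opt == "adam":
        optimizer = torch.optim.Adam(model.parameters(), weight_decay=args.weight_decay)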

V. Setting up the converter

    # CustomConverter: returns the subsampled xs_pad, ilens, and ys_pad
    converter = CustomConverter(subsampling_factor=model.subsample[0], dtype=dtype)
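
A hedged sketch of what CustomConverter.__call__ does in the single-encoder case: subsample each feature matrix along time, record the lengths, and pad features with 0 and labels with -1 (the ignore index):

    import numpy as np
    import torch
    from espnet.nets.pytorch_backend.nets_utils import pad_list

    # Simplified sketch of CustomConverter.__call__ (single-encoder case).
    def convert(batch, subsampling_factor=1, device="cpu"):
        xs, ys = batch[0]  # batch is a list holding one (features, labels) pair
        if subsampling_factor > 1:
            xs = [x[::subsampling_factor, :] for x in xs]
        ilens = torch.from_numpy(np.array([x.shape[0] for x in xs]))
        xs_pad = pad_list([torch.from_numpy(x).float() for x in xs], 0).to(device)
        ys_pad = pad_list([torch.from_numpy(y).long() for y in ys], -1).to(device)
        return xs_pad, ilens, ys_pad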

VI. Loading the data

1. make_batchset reads the data from the json and turns it into a batch set of type List[List[Tuple[str, dict]]].

Usage of make_batchset:
    >>> data = {'utt1': {'category': 'A', 'input': ...},
    ...         'utt2': {'category': 'B', 'input': ...},
    ...         'utt3': {'category': 'B', 'input': ...},
    ...         'utt4': {'category': 'A', 'input': ...}}
    >>> make_batchset(data, batchsize=2, ...)
    [[('utt1', ...), ('utt4', ...)], [('utt2', ...), ('utt3', ...)]]
    # Read the training data (the validation set works the same way)
    with open(args.train_json, "rb") as f:
        train_json = json.load(f)["utts"]
    # Build the training batch set
    train = make_batchset(
        train_json,
        args.batch_size,
        args.maxlen_in,
        args.maxlen_out,
        args.minibatches,
        min_batch_size=args.ngpu if args.ngpu > 1 else 1,
        shortest_first=use_sortagrad,
        count=args.batch_count,
        batch_bins=args.batch_bins,
        batch_frames_in=args.batch_frames_in,
        batch_frames_out=args.batch_frames_out,
        batch_frames_inout=args.batch_frames_inout,
        iaxis=0,
        oaxis=0,
    )
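
The core idea behind make_batchset in the default count="seq" mode is "sort utterances by input length, then slice into chunks". A toy illustration of that idea only (the real function also shrinks batches for long utterances via maxlen_in/maxlen_out, supports other count modes, etc.):

    # Toy illustration of the sort-then-chunk idea behind make_batchset.
    def toy_batchset(data, batchsize):
        sorted_utts = sorted(
            data.items(),
            key=lambda kv: int(kv[1]["input"][0]["shape"][0]),
            reverse=True,  # longest first
        )
        return [
            sorted_utts[i : i + batchsize]
            for i in range(0, len(sorted_utts), batchsize)
        ]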

2. LoadInputsAndTargets builds the mini-batches; its call method, __call__(self, batch, return_uttid=False), extracts the input feature matrices (feats) and the labels (targets) from the dict:
feats = [(T_1, D), (T_2, D), …, (T_B, D)]
targets = [(L_1), (L_2), …, (L_B)]

Usage of LoadInputsAndTargets:
>>> batch = [('utt1',
...           dict(input=[dict(feat='some.ark:123',
...                            filetype='mat',
...                            name='input1',
...                            shape=[100, 80])],
...                output=[dict(tokenid='1 2 3 4',
...                             name='target1',
...                             shape=[4, 31])]))]
>>> load_tr = LoadInputsAndTargets()
>>> feat, target = load_tr(batch)
    load_tr = LoadInputsAndTargets(
        mode="asr",
        load_output=True,
        preprocess_conf=args.preprocess_conf,  # preprocessing config, e.g. specaug
        preprocess_args={"train": True},  # Switch the mode of preprocessing
    )

3. ChainerDataLoader is a Chainer-style PyTorch DataLoader.
TransformDataset wraps the data as a PyTorch Dataset:

    class TransformDataset(torch.utils.data.Dataset):
        def __init__(self, data, transform):
            super().__init__()
            self.data = data
            self.transform = transform

        def __len__(self):
            return len(self.data)

        def __getitem__(self, idx):
            return self.transform(self.data[idx])

    train_iter = ChainerDataLoader(
        dataset=TransformDataset(train, lambda data: converter([load_tr(data)])),
        batch_size=1,
        num_workers=args.n_iter_processes,
        shuffle=not use_sortagrad,
        collate_fn=lambda x: x[0],
    )
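
A hedged sketch of the ChainerDataLoader wrapper itself: it owns a torch DataLoader and exposes the Chainer-iterator interface (a next() method plus an epoch counter) that CustomUpdater expects:

    import torch

    # Simplified sketch of espnet's ChainerDataLoader (details may differ).
    class ChainerDataLoaderSketch:
        def __init__(self, **kwargs):
            self.loader = torch.utils.data.DataLoader(**kwargs)
            self.len = len(kwargs["dataset"])
            self.current_position = 0
            self.epoch = 0
            self.iter = None

        def next(self):
            if self.iter is None:
                self.iter = iter(self.loader)
            try:
                ret = next(self.iter)
            except StopIteration:  # restart the underlying iterator
                self.iter = iter(self.loader)
                ret = next(self.iter)
            self.current_position += 1
            if self.current_position == self.len:  # one full pass over the data
                self.epoch += 1
                self.current_position = 0
            return ret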

VII. Setting up the Updater

A custom CustomUpdater is used; its core code, simplified, is:

    def update_core(self):
        # When we pass one iterator and optimizer to StandardUpdater.__init__,
        # they are automatically named 'main'.
        train_iter = self.get_iterator("main")
        optimizer = self.get_optimizer("main")
        epoch = train_iter.epoch

        batch = train_iter.next()
        x = _recursive_to(batch, self.device)
        is_new_epoch = train_iter.epoch != epoch

        loss = data_parallel(self.model, x, range(self.ngpu)).mean() / self.accum_grad
        loss.backward()  # backpropagation
        # A form of regularization: gradient noise injection
        if self.grad_noise:
            ...

        self.forward_count += 1
        if not is_new_epoch and self.forward_count != self.accum_grad:
            return

        self.forward_count = 0
        # compute grad_norm and check that the gradients are sane
        ...
        optimizer.step()  # update the parameters
        optimizer.zero_grad()  # reset the gradients

    def update(self):
        self.update_core()
        if self.forward_count == 0:
            self.iteration += 1

    updater = CustomUpdater(
        model,
        args.grad_clip,  # gradients whose norm exceeds this threshold are clipped back into range to prevent gradient explosion
        {"main": train_iter},  # chainer iterator
        optimizer,
        device,
        args.ngpu,
        args.grad_noise,  # a form of regularization: gradient noise injection
        args.accum_grad,  # gradient accumulation (default 2, i.e. the gradients are zeroed every two steps)
        use_apex=use_apex,
    )
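
The grad_norm check elided above is, in most ESPnet versions, a gradient clip followed by a NaN guard; a hedged sketch:

    import logging
    import math

    import torch

    # Hedged sketch of the elided gradient-norm check in update_core.
    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
    logging.info("grad norm={}".format(grad_norm))
    if math.isnan(grad_norm):
        logging.warning("grad norm is nan. Do not update model.")
    else:
        optimizer.step()
    optimizer.zero_grad()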

VIII. Setting up the Chainer trainer

The form is trainer = training.Trainer(updater, (max_epoch, 'epoch'), out=path):

    # Set up the Chainer trainer
    trainer = training.Trainer(updater, (args.epochs, "epoch"), out=args.outdir)

IX. Trainer extensions

    # Evaluate the model
    trainer.extend(CustomEvaluator(model, {"main": valid_iter}, reporter, device, args.ngpu))
    # Save the attention weights every epoch
    trainer.extend(att_reporter, trigger=(1, "epoch"))
    # Save the CTC probabilities every epoch
    trainer.extend(ctc_reporter, trigger=(1, "epoch"))
    
    # Plot loss.png
    trainer.extend(
        extensions.PlotReport(
            [
                "main/loss",
                "validation/main/loss",
                "main/loss_ctc",
                "validation/main/loss_ctc",
                "main/loss_att",
                "validation/main/loss_att"
            ],
            "epoch",
            file_name="loss.png",
        )
    )
    # Plot acc.png
    trainer.extend(
        extensions.PlotReport(
            ["main/acc", "validation/main/acc"], "epoch", file_name="acc.png"
        )
    )
    # Plot cer.png
    trainer.extend(
        extensions.PlotReport(
            ["main/cer_ctc", "validation/main/cer_ctc"], "epoch", file_name="cer.png",
        )
    )

    # Save the model with the best validation loss
    trainer.extend(
        snapshot_object(model, "model.loss.best"),
        trigger=training.triggers.MinValueTrigger("validation/main/loss"),
    )
    # Save the model with the best validation accuracy
    trainer.extend(
        snapshot_object(model, "model.acc.best"),
        trigger=training.triggers.MaxValueTrigger("validation/main/acc"),
    )
    # Save a snapshot every epoch (used for model averaging)
    trainer.extend(torch_snapshot(), trigger=(1, "epoch"))
	
    # Write a record to train.log every report_interval_iters (100 by default) iterations
    trainer.extend(
        extensions.LogReport(trigger=(args.report_interval_iters, "iteration"))
    )   
    # Every report_interval_iters iterations, print report_keys to the log, including "epoch", "iteration", "main/loss", ...
    trainer.extend(
        extensions.PrintReport(report_keys),
        trigger=(args.report_interval_iters, "iteration"),
    )
    # Draw a progress bar in train.log every report_interval_iters iterations
    trainer.extend(extensions.ProgressBar(update_interval=args.report_interval_iters))
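
If args.tensorboard_dir is set, asr.py also mirrors these values to TensorBoard; a hedged sketch of that extension (class and module names from memory):

    from torch.utils.tensorboard import SummaryWriter
    from espnet.utils.training.tensorboard_logger import TensorboardLogger

    # Hedged sketch: also log the reported values to TensorBoard.
    if args.tensorboard_dir is not None and args.tensorboard_dir != "":
        trainer.extend(
            TensorboardLogger(SummaryWriter(args.tensorboard_dir), att_reporter),
            trigger=(args.report_interval_iters, "iteration"),
        )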

X. Setting up early stopping

    set_early_stop(trainer, args)
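
set_early_stop installs a Chainer EarlyStoppingTrigger when args.patience > 0, watching args.early_stop_criterion (e.g. "validation/main/acc"). A hedged sketch of the gist:

    from chainer.training.triggers import EarlyStoppingTrigger

    # Hedged sketch of what set_early_stop installs when patience > 0.
    if args.patience > 0:
        trainer.stop_trigger = EarlyStoppingTrigger(
            monitor=args.early_stop_criterion,  # e.g. "validation/main/acc"
            mode="max" if "acc" in args.early_stop_criterion else "min",
            patients=args.patience,  # note: chainer spells the parameter "patients"
            max_trigger=(args.epochs, "epoch"),
        )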

XI. Running

    trainer.run()
    check_early_stop(trainer, args.epochs)