2021-04-13

LSTM
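The two figures originally embedded here showed the LSTM baseline used for dialect classification. Since model.baseline1 itself is not reproduced in this post, the following is a minimal sketch of what such a model might look like, assuming an LSTM over the MFCC frames followed by mean pooling over time and a linear classifier. The class name Baseline1Sketch, the hidden size, the pooling choice, and the way LFR stacking scales the input dimension are assumptions, not the repository's actual code:

import torch.nn as nn

class Baseline1Sketch(nn.Module):
    def __init__(self, args, hidden_size=256):
        super().__init__()
        # args.d_input MFCC dims (times args.LFR_m after frame stacking, an assumption),
        # args.n_layers LSTM layers, args.class_num dialect classes
        self.lstm = nn.LSTM(input_size=args.d_input * args.LFR_m,
                            hidden_size=hidden_size,
                            num_layers=args.n_layers,
                            batch_first=True)
        self.fc = nn.Linear(hidden_size, args.class_num)

    def forward(self, padded_input):
        # padded_input: (batch, time, feature_dim)
        output, _ = self.lstm(padded_input)
        pooled = output.mean(dim=1)   # average over the time axis
        return self.fc(pooled)        # (batch, class_num) logits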

import sys
sys.path.append('../')

import argparse
import torch
import os
import numpy as np
from data.data import AudioDataLoader, AudioDataset
from solver.solver import Solver
from model.baseline1 import Baseline1

parser = argparse.ArgumentParser("Datawhale NLP")

# General config

# Task related

parser.add_argument('--train-mfccjson', type=str, default='../dump/data_train.json',
                    help='Filename of train label data (json)')
parser.add_argument('--valid-mfccjson', type=str, default='../dump/data_dev.json',
                    help='Filename of validation label data (json)')

# Low Frame Rate (stacking and skipping frames)

parser.add_argument('--LFR_m', default=1, type=int,
                    help='Low Frame Rate: number of frames to stack')  # 4
parser.add_argument('--LFR_n', default=1, type=int,
                    help='Low Frame Rate: number of frames to skip')  # 3
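# Illustrative sketch (not part of the original script): what LFR_m / LFR_n
# presumably mean in the data loader -- stack LFR_m consecutive frames into one
# feature vector, then advance by LFR_n frames.  With LFR_m=4, LFR_n=3, an
# utterance of T frames with D-dim MFCCs becomes roughly ceil(T/3) frames of
# dimension 4*D.  The function name and the tail-padding strategy are assumptions.
def build_LFR_features_sketch(inputs, m, n):
    # inputs: (T, D) numpy array of frame-level features
    T, D = inputs.shape
    out = []
    for i in range(0, T, n):
        frames = inputs[i:i + m]
        if len(frames) < m:
            # pad the tail by repeating the last frame
            frames = np.concatenate([frames] + [frames[-1:]] * (m - len(frames)))
        out.append(frames.reshape(-1))
    return np.stack(out)  # shape: (ceil(T / n), m * D)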

# model

parser.add_argument('--class_num', default=20, type=int,  # <-- change here
                    help='Number of dialect classes')
parser.add_argument('--d_input', default=60, type=int,  # <-- change here
                    help='Dim of encoder input (before LFR)')
parser.add_argument('--n_layers', default=1, type=int,  # <-- change here
                    help='Number of LSTM layers')

# save and load model

parser.add_argument('--save-folder', default='exp',
                    help='Location to save epoch models')
parser.add_argument('--checkpoint', dest='checkpoint', default=0, type=int,
                    help='Enables checkpoint saving of model')
parser.add_argument('--continue-from', default='',
                    help='Continue from checkpoint model')
parser.add_argument('--model-path', default='mfcc_datawhale.pth.tar',  # <-- change here
                    help='Location to save best validation model')

# logging

parser.add_argument('--print-freq', default=100, type=int,
                    help='Frequency of printing training information')
parser.add_argument('--visdom', dest='visdom', type=int, default=0,
                    help='Turn on visdom graphing')
parser.add_argument('--visdom_lr', dest='visdom_lr', type=int, default=0,
                    help='Turn on visdom graphing of the learning rate')
parser.add_argument('--visdom_epoch', dest='visdom_epoch', type=int, default=0,
                    help='Turn on visdom graphing for each epoch')
parser.add_argument('--visdom-id', default='Transformer training',
                    help='Identifier for visdom run')

# Training config

parser.add_argument('--epochs', default=100, type=int,
                    help='Maximum number of epochs')

parser.add_argument('--lr', default=0.0005, type=float,
                    help='Learning rate')

# minibatch

parser.add_argument('--shuffle', default=1, type=int,
                    help='Reshuffle the data at every epoch')
parser.add_argument('--batch-size', default=16, type=int,
                    help='Batch size')
parser.add_argument('--batch_frames', default=10000, type=int,  # 20000
                    help='Batch frames. If this is not 0, batch size is ignored')  # 15000 10000
parser.add_argument('--maxlen-in', default=800, type=int, metavar='ML',
                    help='Batch size is reduced if the input sequence length > ML')
parser.add_argument('--maxlen-out', default=150, type=int, metavar='ML',
                    help='Batch size is reduced if the output sequence length > ML')
parser.add_argument('--num-workers', default=0, type=int,
                    help='Number of workers used to generate minibatches')
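# Illustrative sketch (an assumption about AudioDataset, not code from the repo):
# when batch_frames > 0, minibatches are typically built by greedily grouping
# length-sorted utterances until the total number of frames would exceed
# batch_frames, instead of using a fixed batch size.
def bucket_by_frames_sketch(utts, batch_frames):
    # utts: list of (utt_id, num_frames) pairs, sorted by num_frames
    batches, cur, cur_frames = [], [], 0
    for utt_id, n_frames in utts:
        if cur and cur_frames + n_frames > batch_frames:
            batches.append(cur)
            cur, cur_frames = [], 0
        cur.append(utt_id)
        cur_frames += n_frames
    if cur:
        batches.append(cur)
    return batches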

parser.add_argument('--model_choose', default='baseline', type=str,  # <-- change here
                    help='Choose model type')
# baseline4
parser.add_argument('--layer_choose', default='layer1', type=str,
                    help='Choose which layer to use')  # layer1, layer2, layer3, layer4, all

def main(args):
    # Construct Solver
    np.random.seed(1)
    torch.manual_seed(1)
    torch.cuda.manual_seed_all(1)
    torch.backends.cudnn.deterministic = True  # make each run reproducible
    # data
    tr_dataset = AudioDataset(args.train_mfccjson, args.batch_size,
                              args.maxlen_in, args.maxlen_out,
                              batch_frames=args.batch_frames)
    cv_dataset = AudioDataset(args.valid_mfccjson, args.batch_size,
                              args.maxlen_in, args.maxlen_out,
                              batch_frames=args.batch_frames)
    tr_loader = AudioDataLoader(tr_dataset, batch_size=1,
                                num_workers=args.num_workers,
                                shuffle=args.shuffle,
                                LFR_m=args.LFR_m, LFR_n=args.LFR_n,
                                model_choose=args.model_choose)
    cv_loader = AudioDataLoader(cv_dataset, batch_size=1,
                                num_workers=args.num_workers,
                                LFR_m=args.LFR_m, LFR_n=args.LFR_n,
                                model_choose=args.model_choose)

    data = {'tr_loader': tr_loader, 'cv_loader': cv_loader}

    # model
    model = Baseline1(args)
    print(model)
    print('model parameters:', sum(param.numel() for param in model.parameters()))
    model.cuda()

    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()),
                                 betas=(0.9, 0.98), eps=1e-09, lr=args.lr)

    # solver
    solver = Solver(data, model, optimizer, args)
    solver.train()


if __name__ == '__main__':
    args = parser.parse_args()
    print(args)
    main(args)
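To launch training, an invocation along these lines should work (the script filename train.py and the flag values are illustrative; only the two JSON manifests under ../dump/ need to exist):

python train.py --train-mfccjson ../dump/data_train.json --valid-mfccjson ../dump/data_dev.json --batch-size 16 --batch_frames 10000 --lr 0.0005 --epochs 100 --save-folder exp --checkpoint 1 --model_choose baseline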