AI - Getting Started with Speech
This blog post is the author's own summary of an IMOOC (慕课网) video tutorial; link to the original video
References:
Speech Enhancement: Theory and Practice - [US] Loizou et al., translated by Gao Yi et al.
The relationship and differences between WAV and PCM
The AudioSet dataset
Zhihu - understanding MFCC and CMVN
What is the physical meaning of MFCC in speech processing?
Speech Fundamentals
An introduction to the basics
1. How speech is produced
Human speech frequency range: roughly 100 Hz-10 kHz
Fundamental frequency (pitch): if the speech waveform is viewed as a superposition of sinusoids, the fundamental is the frequency of the lowest-frequency component
Male pitch range: 50-250 Hz
Female pitch range: 100-500 Hz
[Detailed analysis]
- Lungs: the main excitation source for speech. During inhalation the diaphragm lowers, pressure in the lungs drops, and air flows through the vocal tract and trachea into the lungs; during exhalation the chest muscles contract, chest volume shrinks, lung pressure rises, and air flows out through the trachea to the larynx.
- Larynx: built from muscles, ligaments, and cartilage, it controls the vocal folds; the gap between the two folds is called the glottis. In the breathing state, air from the lungs flows freely through the glottis; in the voiced state, a rapid increase in vocal-fold tension accompanied by rapid pressure changes at the glottis makes the folds open and close periodically; in the unvoiced state the folds do not vibrate, but they are tenser and closer together, so the airflow through the glottis becomes turbulent.
- Vocal tract: consists of the oral and nasal cavities. Different regions of the oral cavity take on different shapes depending on the positions of the tongue, teeth, lips, and palate (collectively called the articulators). The vocal tract acts like a physical linear filter, reshaping the spectrum of the airflow waveform delivered by the vocal folds to produce different sounds.
The periodic opening and closing of the vocal folds works as follows: the folds close, pressure builds up and forces them open, airflow rises from zero to a maximum, then the folds close again and the airflow drops back to zero.
√ Closed glottal phase: no air flows through the vocal tract while the folds are closed;
√ Open glottal phase: the interval during which the folds are open;
√ Pitch period: the duration of one open-close cycle of the glottis;
√ Fundamental frequency: the reciprocal of the pitch period.
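Since the fundamental frequency is just the reciprocal of the pitch period, a pitch estimate falls out of finding the dominant period. A minimal autocorrelation sketch on a synthetic 200 Hz tone (not code from the tutorial; the sample rate and search range are assumptions):

```python
import numpy as np

sr = 16000                      # assumed sample rate
f0 = 200.0                      # known pitch of the synthetic "voiced" signal
t = np.arange(2048) / sr
x = np.sin(2 * np.pi * f0 * t)  # stand-in for a voiced speech segment

# autocorrelation; after slicing, index k is the correlation at lag k
ac = np.correlate(x, x, mode='full')[len(x) - 1:]

# search lags corresponding to 50-500 Hz, a typical pitch range
lag_min, lag_max = sr // 500, sr // 50
best_lag = lag_min + np.argmax(ac[lag_min:lag_max])
pitch_estimate = sr / best_lag  # F0 = 1 / pitch period
```

The highest autocorrelation peak inside the search range sits at one pitch period (80 samples here), so the estimate recovers 200 Hz.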
2. How hearing works
Outer ear: sound source localization (DOA), amplification (AGC)
Middle ear: protects the cochlea
Inner ear: behaves like a mel filter bank
Human hearing range: 20 Hz-20 kHz
3. Audio and its parameters
Technology framework and characteristics
1. Scope of the technology
2. Technology framework
When choosing a technology framework, take the different application scenarios and requirements into account.
The evolution of speech technology
1. Automatic speech recognition
Statistically, human performance is around 5.6% WER, while current SOTA models have reached about 1.4%; this is exactly why ASR was the first of the speech technologies to be commercially deployed.
2. Speech synthesis
Commercial deployment of speech synthesis has also been quite successful;
The so-called [concatenative] approach is essentially database lookup plus template matching, whereas the modeling/analysis approach instead builds a model of the human speech-production process.
Audio features and prerequisites
------------------ Contents ---------------------------
Basic audio operations with Python
Basic operations in Audition
Spectrograms explained
Mel feature extraction
Related neural-network components
Audio file processing and operations
Common audio formats in speech applications: WAV and PCM
1. Basic operations in AU
2. Spectrogram analysis and the notion of domains
A spectrogram is actually a three-dimensional plot: the horizontal axis is time, the vertical axis is frequency, and the color intensity shows the energy at a given frequency at a given moment.
[Details]
A spectrogram displays how speech power varies with time t, i.e. the relative energy distribution of the signal over frequency as a function of t: S(n, ω) = |X(n, ω)|²;
Depending on the window length used when computing S(n, ω), you obtain either a wideband or a narrowband spectrogram.
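The wideband/narrowband distinction comes straight from the time-frequency trade-off: frequency resolution per bin is roughly the sample rate divided by the window length. A quick numeric illustration (the window lengths are typical textbook values, not from the tutorial):

```python
sr = 16000  # assumed sample rate

# Narrowband spectrogram: long window (~30 ms) -> fine frequency resolution,
# so individual harmonics are visible, but time resolution is coarse.
narrow_win = 30 * sr // 1000        # 480 samples
freq_res_narrow = sr / narrow_win   # ~33 Hz per bin

# Wideband spectrogram: short window (~5 ms) -> fine time resolution,
# so individual glottal pulses are visible, but frequency bins are wide.
wide_win = 5 * sr // 1000           # 80 samples
freq_res_wide = sr / wide_win       # 200 Hz per bin
```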
3. Python speech-processing libraries
① Because human loudness perception is not linear in amplitude, halving the amplitude is not the same as halving the loudness;
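To see why, compare amplitude ratios on a decibel scale: halving the amplitude is only about -6 dB, whereas "half as loud" perceptually corresponds to roughly -10 dB (a common psychoacoustic rule of thumb; this sketch is not from the tutorial):

```python
import math

def gain_db(amplitude_ratio):
    # convert an amplitude ratio to decibels
    return 20 * math.log10(amplitude_ratio)

halved = gain_db(0.5)         # ~ -6.02 dB: clearly quieter, but not half as loud
half_loud = 10 ** (-10 / 20)  # ~0.316: the amplitude ratio for a -10 dB change
```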
import numpy as np
import matplotlib.pyplot as plt
import scipy.io.wavfile as wf


class WaveProc(object):

    @staticmethod
    def wave_read(wave_path):
        """Read a wav file.
        :param wave_path: path to the wav file
        :return: sample rate, duration, channel count, and ndarray data
        """
        assert wave_path.endswith('wav'), 'Not Supported File Format!'
        sr, wave_data = wf.read(wave_path)  # [frames, channels]: e.g. [16000, 2]
        frame_num = wave_data.shape[0]
        duration = frame_num / sr
        if len(wave_data.shape) == 1:
            channel_num = 1
        else:
            channel_num = 2
        # reading audio with more channels may require a library such as wavio
        return sr, duration, channel_num, wave_data

    @staticmethod
    def pcm_read(pcm_path, sr):
        assert pcm_path.endswith('pcm'), 'Not Supported File Format!'
        wave_data = np.fromfile(pcm_path, dtype=np.int16)
        frame_num = wave_data.shape[0]
        duration = frame_num / sr
        if len(wave_data.shape) == 1:
            channel_num = 1
        else:
            channel_num = 2
        return sr, duration, channel_num, wave_data

    def time_domain_display(self, wav_path):
        sr, duration, channel_num, data = self.wave_read(wav_path)
        plt.xlabel('Time')
        plt.ylabel('Amplitude')
        plt.title('Time Domain')
        if channel_num == 1:
            plt.plot(data)
        elif channel_num == 2:
            plt.subplot(211)
            plt.plot(data[:, 0])
            plt.subplot(212)
            plt.plot(data[:, 1])
        plt.show()

    @staticmethod
    def wave_write(wave_data, output_path, sr):
        """Write audio to disk.
        :param wave_data: ndarray audio data
        :param output_path: output path
        :param sr: sample rate
        :return: None
        """
        assert output_path.endswith('wav'), "Not Supported File Format"
        wf.write(output_path, sr, wave_data)
Speech features and perception
1. Feature extraction
√ Convert the speech waveform into another parametric form with a relatively low data rate, for subsequent processing and analysis;
In the figure above, all four speech segments are the same speaker pronouncing "1"; because the data in the time-domain waveform changes too rapidly, we would rather analyze parameters in some other feature domain.
2. Spectra and spectrograms
① Spectrum and cepstrum
√ Spectrum: spectral coefficients are obtained by a transform from the time domain to the frequency domain;
√ Cepstrum: after transforming from the time domain to the frequency domain, apply an operation such as the logarithm and then transform back to the time domain.
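Following that definition, a real cepstrum can be sketched in a few lines (a minimal illustration, not code from the course; the frame length and FFT size are arbitrary):

```python
import numpy as np

def real_cepstrum(x, n_fft=512):
    # time domain -> frequency domain
    spectrum = np.fft.rfft(x, n_fft)
    # logarithm of the magnitude spectrum (small constant avoids log(0))
    log_mag = np.log(np.abs(spectrum) + 1e-10)
    # ... and back to the (quefrency) "time" domain
    return np.fft.irfft(log_mag, n_fft)

frame = np.hamming(400) * np.random.randn(400)  # stand-in for one speech frame
ceps = real_cepstrum(frame)                     # one cepstrum per frame
```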
② Spectrograms
3. Audio features
① Mel spectrogram
√ The mel scale
The mechanism is a transform of the frequency axis by the formula m = 2595 · log10(1 + f/700). The reason for doing this is the nonlinearity of human auditory perception: after conversion to the mel scale, growth of the quantity becomes approximately linear in perception, so when a value on the mel scale doubles, the perceived pitch roughly doubles as well.
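The forward and inverse mappings of this HTK-style mel formula are easy to check numerically (a standalone sketch; librosa provides an equivalent in librosa.hz_to_mel):

```python
import math

def hz_to_mel(f):
    # HTK-style mel scale: m = 2595 * log10(1 + f / 700)
    return 2595.0 * math.log10(1.0 + f / 700.0)

def mel_to_hz(m):
    # exact inverse of the mapping above
    return 700.0 * (10.0 ** (m / 2595.0) - 1.0)

# 1000 Hz maps to roughly 1000 mel by construction of the scale
m1k = hz_to_mel(1000.0)
```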
√ Extracting mel features
① It preserves the useful information in the audio while reducing computation;
② The number of mel filter-bank channels is a free choice; the resulting mel spectrogram has shape (frames, channels);
③ More channels means more computation, but also a finer-grained description of the audio features;
④ For ASR tasks, 40 channels is more than enough.
import numpy as np
from librosa.filters import mel as librosa_mel_fn
import matplotlib.pyplot as plt
# librosa_mel_fn builds a bank of mel triangular filters


class FeatureExt(object):

    def __init__(self, sr, n_mel_channels):
        self.sr = sr
        self.frame_size = int(25 * sr / 1000)    # 25 ms window (400 samples at 16 kHz)
        self.frame_stride = int(10 * sr / 1000)  # 10 ms hop (160 samples at 16 kHz)
        self.n_mel_channels = n_mel_channels     # desired number of mel channels
        self.fmin, self.fmax = 0, int(sr / 2)    # cut-off frequencies of the triangular filter bank
        self.NFFT = 512  # FFT size: the smallest power of two not less than the window length
        self.mel_bias = librosa_mel_fn(sr=sr, n_fft=self.NFFT,
                                       n_mels=self.n_mel_channels,
                                       fmin=self.fmin, fmax=self.fmax)

    def mel_calc(self, signal):
        # A. Pre-emphasis (optional): boost high frequencies; coefficient usually about 0.97
        pre_emphasis = 0.97
        emphasized_signal = np.append(signal[0], signal[1:] - pre_emphasis * signal[:-1])
        # B. Framing: how speech goes from the time domain to the time-frequency domain
        signal_length = len(emphasized_signal)
        num_frames = int(np.floor(float(np.abs(signal_length - self.frame_size)) / self.frame_stride)) + 1
        # total length needed so every frame of the short-time analysis is complete,
        # given the frame count, frame size, and hop size
        pad_signal_length = num_frames * self.frame_stride + self.frame_size
        z = np.zeros(pad_signal_length - signal_length)
        pad_signal = np.append(emphasized_signal, z)
        indices = np.tile(np.arange(0, self.frame_size), (num_frames, 1)) + \
            np.tile(np.arange(0, num_frames * self.frame_stride, self.frame_stride),
                    (self.frame_size, 1)).T
        frames = pad_signal[indices.astype(np.int32)]
        # C. Windowing (optional): suppresses spectral leakage; alternatives include
        #    the Hann, rectangular, and Povey windows
        frames *= np.hamming(self.frame_size)
        # D. Short-time Fourier transform: FFT size is the smallest 2^n above the
        #    window length, 512 > 400
        frames_fft = np.fft.rfft(frames, self.NFFT)  # stft
        # E. Magnitude spectrum
        mag_frames = np.absolute(frames_fft)  # magnitude of the FFT
        # F. Apply the triangular mel filter bank
        mel = np.dot(self.mel_bias, mag_frames.T)
        print('mel spectrogram shape:', mel.shape)
        print('magnitude spectrum shape:', mag_frames.shape)
        return mel

    @staticmethod
    def display(mel):
        mel = np.flipud(mel)
        plt.imshow(mel)
        plt.show()


if __name__ == '__main__':
    from wav_proc import *
    wave_path = './data/16k-2bytes-mono.wav'
    WP = WaveProc()
    sr, duration, channel_num, data = WP.wave_read(wave_path)
    FE = FeatureExt(16000, 80)  # 40, 80
    mel = FE.mel_calc(data)
    print(mel.shape)
    FE.display(mel)
Related neural-network components
- Recurrent neural networks: RNN
- Convolutional neural networks: Conv1d/Conv2d
- Attention
- Residual connections
Audio event detection in practice with deep learning
- Setting up the project
- Data preparation
- Feature extraction, using open-source libraries
- Model building, tied to the specific scenario
- Data loading
- Model training and logging
- Microphone capture and deployment
Datasets and data
1. AudioSet
The AudioSet dataset is a collection of 10-second audio clips gathered from YouTube videos and manually annotated.
The labeling relies on YouTube metadata and content-based search.
2. Audio data validation
Before the data is fed into training, it should always be validated to make sure its format and shape are correct.
import os
import wave

'''
Validate the audio data.
'''
dataset_dir = ''  # dataset directory
# Validation mainly checks the channels, sample width, sample rate, and duration.
# ----- property ----- sample rate ----- channels ----- width ------- length ----
# ----- value ----------- 16k ----------- mono ------- 2 bytes ------- 10s ------
for root, sub_dir, items in os.walk(dataset_dir):
    if not sub_dir:
        for it in items:
            f = wave.open(os.path.join(root, it), 'rb')
            # read the file with the wave module and parse its header parameters
            params = f.getparams()
            nchannel, width, sample_rate, nframes = params[:4]
            f.close()
            # check
            assert nchannel == 1 and width == 2 and sample_rate == 16000 and nframes == 10 * 16000
        print(os.path.split(root)[1], 'passed!')
print('dataset passed!')
For applications such as synthesis that place higher demands on speech quality, audio validation should check not only the parameters above but also:
- clipping, i.e. whether the amplitude overflows
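A simple clipping check for 16-bit PCM might look like this (a sketch; the 0.1%-of-full-scale threshold is an assumption, not part of the course code):

```python
import numpy as np

def is_clipped(wave_data, bit_depth=16, threshold_ratio=0.999):
    # samples sitting at (or within 0.1% of) the integer full-scale limit
    # are a strong hint that the recording clipped
    limit = 2 ** (bit_depth - 1) - 1             # 32767 for 16-bit audio
    samples = np.abs(wave_data.astype(np.int64))  # cast first: abs(-32768) overflows int16
    return bool(np.any(samples >= int(limit * threshold_ratio)))

clean = np.array([0, 1200, -8000], dtype=np.int16)
clipped = np.array([0, 32767, -32768], dtype=np.int16)
```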
DataLoader
In audio event detection, the DataLoader instance is mainly responsible for:
① reading and parsing the metadata file (meta file)
② importing the mel-spectrogram computation instance
import torch
from torch.utils.data import DataLoader
from utils import meta_parse, load_wav_to_torch
from audio_processing import MelSpec


class MelLoader(torch.utils.data.Dataset):

    def __init__(self, metafile_path, cfg):
        """Dataset class for mel features.
        :param metafile_path: dataset metafile
        :param cfg: configuration
        """
        # meta parse
        self.items = meta_parse(metafile_path)     # [(wav_path, label)]
        self.max_wav_value = cfg['max_wav_value']  # 32768.
        self.sampling_rate = cfg['sampling_rate']
        self.device = cfg['device']
        # mel computation instance
        self.mel_calc_inst = MelSpec(cfg['win_len'],          # stft win: 512
                                     cfg['hop'],              # stride: 160
                                     cfg['nfilter'],          # mel calc win: 400
                                     cfg['n_mel_channels'],   # 40
                                     cfg['sampling_rate'],    # 16000
                                     cfg['mel_fmin'],         # 50 Hz
                                     cfg['mel_fmax'])         # 800 Hz

    def get_mel(self, wav_path):
        """
        :param wav_path: audio path
        :return: mel features
        """
        audio, sr = load_wav_to_torch(wav_path)
        assert sr == self.sampling_rate, 'sample rate not match!'
        audio_norm = audio / self.max_wav_value  # amplitude normalization
        audio_norm = audio_norm.unsqueeze(0)     # [n,] -> [1, n]
        melspec = self.mel_calc_inst.mel_spectrogram(audio_norm)  # mel computation
        melspec = torch.squeeze(melspec, 0)
        return melspec

    def __getitem__(self, idx):
        wave_path, label = self.items[idx]
        return self.get_mel(wave_path).to(self.device), \
            torch.tensor(int(label)).long().to(self.device)

    def __len__(self):
        return len(self.items)
Feature extraction
- Writing the configuration file
- An overview of utils.py
- The mel-spectrogram extraction class
1. Model parameter configuration and loading
# data config
# File: cfg.yaml
# ################################################################
# Data
# ################################################################
device: 'cpu'                    # compute device
dataset_dir: './dataset/'        # dataset directory
meta_path: './dataset/meta.csv'  # dataset metafile
max_wav_value: 32768.            # maximum magnitude of 16-bit samples
# 16-bit samples range over -32768..32767; dividing by 32768 normalizes to [-1, 1)
sampling_rate: 16000             # sampling rate
win_len: 512                     # fft size: the smallest power of two above the window length
hop: 160                         # frame stride, 10 ms
nfilter: 400                     # frame size, 25 ms at a 16 kHz sampling rate
n_mel_channels: 40               # mel channels
mel_fmin: 50                     # minimum cut-off frequency
mel_fmax: 800                    # maximum cut-off frequency
seed: 1234                       # random seed

# File: cfg_parse.py
import yaml

cfg = yaml.safe_load(open('./cfg.yaml'))
2. Interpreting MFCCs
In the theoretical derivation an inverse Fourier transform is used, while in the practical computation a bank of triangular filters is applied; the physical meaning is similar in both cases: they describe how the spectral energy of the signal is distributed over frequency bands;
Definition of the cepstrum: take the Fourier transform of the signal, apply a logarithm, then take an inverse Fourier transform;
The spectrum of a speech signal can be viewed as the sum of a low-frequency envelope and high-frequency detail, so once the cepstral coefficients are obtained, keeping only the low-order coefficients recovers the envelope.
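This "keep the low-order coefficients" step is exactly what the DCT in MFCC extraction does: applied to a log-mel frame, the first dozen or so coefficients summarize the spectral envelope. A numpy-only sketch (the frame values and sizes are illustrative, not from the course):

```python
import numpy as np

def dct2(x, n_coeffs):
    # DCT-II basis; keeping only the first n_coeffs rows discards the
    # fast-varying (high-quefrency) detail and retains the envelope
    N = len(x)
    k = np.arange(n_coeffs)[:, None]
    n = np.arange(N)[None, :]
    basis = np.cos(np.pi * k * (2 * n + 1) / (2 * N))
    return basis @ x

log_mel_frame = np.log(np.random.rand(40) + 1.0)  # stand-in for one log-mel frame
mfcc = dct2(log_mel_frame, 13)                    # 13 coefficients, a common choice
```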
3. The mel filter-bank computation class
import torch
from librosa.filters import mel as librosa_mel_fn
from stft import STFT


def dynamic_range_compression(x, C=1, clip_val=1e-5):
    """Dynamic range compression.
    :param x: input mel spectrogram
    :param C: compression factor
    :param clip_val: avoids log(0)
    :return: the compressed features

    After the acoustic features are extracted, they are mapped from one space into
    another where they better fit some probability distribution. Compressing the
    dynamic range of the feature values reduces mismatch between training and test
    conditions and helps model robustness; in essence it is a normalization step.
    """
    return torch.log(torch.clamp(x, min=clip_val) * C)


class MelSpec(torch.nn.Module):
    """
    Computes mel features and applies feature compression.
    """

    def __init__(self, filter_length=1024, hop_length=256, win_length=1024,
                 n_mel_channels=40, sampling_rate=16000, mel_fmin=0.0,
                 mel_fmax=8000.0):
        """Mel feature extraction.
        :param filter_length: number of FFT points
        :param hop_length: hop size (stride)
        :param win_length: window length
        :param n_mel_channels: number of mel channels
        :param sampling_rate: sampling rate
        :param mel_fmin: minimum cut-off frequency
        :param mel_fmax: maximum cut-off frequency
        """
        super(MelSpec, self).__init__()
        self.n_mel_channels = n_mel_channels
        self.sampling_rate = sampling_rate
        # an STFT instance
        self.stft_fn = STFT(filter_length=filter_length, hop_length=hop_length,
                            win_length=win_length)
        mel_bias = librosa_mel_fn(sr=sampling_rate, n_fft=filter_length,
                                  n_mels=n_mel_channels,
                                  fmin=mel_fmin, fmax=mel_fmax)
        mel_bias = torch.from_numpy(mel_bias).float()
        # the mel filter bank is a fixed set of parameters, so register it as a buffer
        self.register_buffer('mel_bias', mel_bias)

    def spectral_normalize(self, magnitudes):
        output = dynamic_range_compression(magnitudes)
        return output

    def mel_spectrogram(self, y):
        """Compute mel features.
        :param y: amplitude-normalized audio data
        :return: mel features
        """
        assert torch.min(y) >= -1 and torch.max(y) <= 1
        magnitudes, phase = self.stft_fn.transform(y)  # Fourier transform
        # the magnitude spectrum is mel-filtered here; the power spectrum also works
        magnitudes = magnitudes.data
        mel_output = torch.matmul(self.mel_bias, magnitudes)  # apply the mel triangular filter bank
        mel_output = self.spectral_normalize(mel_output)      # dynamic range compression
        return mel_output
Model construction and training
1. Model structure
2. Building the model
- Hyperparameter configuration
- Writing the modules
- Coding the model
Add the hyperparameter settings for the network model to the cfg.yaml file:
# Model Structure
# ################################################################
n_pre_layer: 2            # prenet layers
pre_layer_hid_dim: 1024   # first-layer hidden size
pre_layer_out_dim: 512    # output-layer hidden size
pre_layer_drop: 0.3       # prenet dropout
n_fc_layers: 2            # fcnet layers
fc_layer_dim: 1024        # first hidden size
n_classes: 6              # classes
fc_layer_drop: 0.3        # fcnet dropout
kernel_size: 5            # kernel size: SAME padding
conv_dim: 512             # conv1d feature channels
n_conv_layers: 3          # conv1d layers
conv_drop: 0.3            # conv1d dropout prob
# so data can flow straight into the next layer, the conv1d feature dimension
# matches the first hidden dimension of the recurrent network
rnn_hid_dim: 512          # RNN hidden size
rnn_layers: 2             # rnn layers
bidirect: true            # bi-directional or not
model.py
import torch
from torch import nn
from utils import mish
from cfg_parse import cfg


class PreNet(nn.Module):
    """
    Front-end transform layers applied to the mel features.
    """

    def __init__(self, cfg):
        super(PreNet, self).__init__()
        self.cfg = cfg
        self.layer_0 = nn.Linear(self.cfg['n_mel_channels'], self.cfg['pre_layer_hid_dim'])
        self.layer_1 = nn.Linear(self.cfg['pre_layer_hid_dim'], self.cfg['pre_layer_out_dim'])
        self.drop = nn.Dropout(self.cfg['pre_layer_drop'])  # 0.3

    # mel input: [bn, n_mel_channels, frames]: [bn, 40, 1001]
    def forward(self, x):  # x: [bn, n_mel_chs, n_frames]
        x = x.permute(0, 2, 1)  # [bn, n_frames, n_mel_chs]
        x = self.layer_0(x)
        x = mish(x)  # Mish (2018); alternatives: ReLU, PReLU, LeakyReLU, sigmoid
        x = self.drop(x)
        return self.layer_1(x)  # [bn, n_frames, pre_layer_out_dim]: [bn, 1001, 512]


class FCNet(nn.Module):
    """
    Classifier head: two linear layers over the final features.
    """

    def __init__(self, cfg):
        super(FCNet, self).__init__()
        self.cfg = cfg
        # for a bidirectional RNN, the final output dimension is doubled
        in_dim = self.cfg['rnn_hid_dim'] * (2 if self.cfg['bidirect'] else 1)
        self.layer_0 = nn.Linear(in_dim, self.cfg['fc_layer_dim'])
        self.layer_1 = nn.Linear(self.cfg['fc_layer_dim'], self.cfg['n_classes'])
        self.drop = nn.Dropout(self.cfg['fc_layer_drop'])

    def forward(self, x):  # x: [bn, 1024/512]
        x = self.layer_0(x)
        x = mish(x)  # activation proposed in 2018; smooth while keeping ReLU's advantages
        x = self.drop(x)
        return self.layer_1(x)  # [bn, 6]


class ConvNorm(torch.nn.Module):
    """
    Thin wrapper around Conv1d, mainly for weight initialization.
    """

    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,
                 padding=None, dilation=1, bias=True, w_init_gain='linear'):
        super(ConvNorm, self).__init__()
        if padding is None:
            assert kernel_size % 2 == 1
            padding = int(dilation * (kernel_size - 1) / 2)
        self.conv = torch.nn.Conv1d(in_channels, out_channels,
                                    kernel_size=kernel_size,
                                    stride=stride,
                                    padding=padding,
                                    dilation=dilation, bias=bias)
        torch.nn.init.xavier_uniform_(
            self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain))

    def forward(self, signal):
        conv_signal = self.conv(signal)
        return conv_signal


class Model(nn.Module):
    """
    The full model.
    """

    def __init__(self, cfg):
        super(Model, self).__init__()
        self.cfg = cfg
        # front-end network
        self.prenet = PreNet(self.cfg)
        # 3 conv1d layers extract short-range features from the transformed input
        conv_layers = []
        for _ in range(self.cfg['n_conv_layers']):
            conv_layer = nn.Sequential(
                ConvNorm(in_channels=self.cfg['conv_dim'],
                         out_channels=self.cfg['conv_dim'],
                         kernel_size=self.cfg['kernel_size'],
                         padding=int(self.cfg['kernel_size'] // 2)),
                nn.BatchNorm1d(self.cfg['conv_dim'])
            )
            conv_layers.append(conv_layer)
        self.convolution1ds = nn.ModuleList(conv_layers)  # list
        self.convDrop = nn.Dropout(self.cfg['conv_drop'])
        # a bidirectional GRU then extracts long-range features
        self.gru = nn.GRU(input_size=self.cfg['conv_dim'],
                          hidden_size=self.cfg['rnn_hid_dim'],
                          num_layers=self.cfg['rnn_layers'],
                          bidirectional=self.cfg['bidirect'])
        # classifier
        self.fc = FCNet(self.cfg)

    def forward(self, x):
        # note the several dimension permutations below
        x = self.prenet(x)      # [bn, n_mel_chs, n_frames] -> [bn, n_frames, chs]
        x = x.permute(0, 2, 1)  # [bn, chs, n_frames]
        for conv in self.convolution1ds:
            x = self.convDrop(mish(conv(x)))  # [bn, chs, n_frames]
        # the GRU expects data laid out to match how it iterates over time
        x = x.permute(2, 0, 1)  # [n_frames, bn, chs]
        rnn_out, hn = self.gru(x)  # [n_frames, bn, channel]: [1001, bn, 1024/512]
        # only the last GRU output is fed to the classifier,
        # since it carries information from the whole preceding sequence
        out = self.fc(rnn_out[-1])  # [bn, 6]
        return out
3. Model training
- Training hyperparameter configuration
- Writing the trainer
- Hyperparameter analysis
import os
from argparse import ArgumentParser
import torch.optim as optim
import torch
import random
import numpy as np
import torch.nn as nn
from model import Model
from cfg_parse import cfg
from dataset_se import MelLoader
from torch.utils.data import DataLoader, Subset
from tensorboardX import SummaryWriter
from collections import Counter

logger = SummaryWriter('./log')  # training log directory

# seed init: ensure reproducible results
seed = 123
torch.manual_seed(seed)
random.seed(seed)
np.random.seed(seed)


# draw a balanced sample from the evaluation set for each validation pass
def balance_sample(dataset_):
    assert isinstance(dataset_, MelLoader)
    items = dataset_.items
    random.shuffle(items)
    sample_indices_dict = {str(k): [] for k in range(6)}
    for idx, (_, lab) in enumerate(items):
        if len(sample_indices_dict[lab]) < 20:
            sample_indices_dict[lab].append(idx)
    sample_indices = []
    [sample_indices.extend(it) for it in list(sample_indices_dict.values())]
    return sample_indices


# compute the proportion of each output class, used to tune the weighted CE
def ratio_calc(lst):
    num = len(lst)
    lst_dict = Counter(lst)
    ret = []
    for lab in range(6):
        if lab not in lst_dict:
            ret.append(0)
        else:
            ret.append(lst_dict[lab] / num)
    ret = ['%.3f' % it for it in ret]
    return ret


ce_fw = open('./log/ce_weigh.log', 'a+')


# evaluate on a balanced subset of the evaluation set and log the class ratios
def evaluate(model_, valset, crit):
    model_.eval()
    subset_indices = balance_sample(valset)
    subset = Subset(valset, subset_indices)
    val_loader = DataLoader(subset, batch_size=1, shuffle=True)
    sum_loss = 0.
    with torch.no_grad():
        y_actual_list, y_pred_list = [], []
        for batch in val_loader:
            inputs, lab = batch
            inputs, lab = inputs.to(cfg['device']), lab.to(cfg['device'])
            pred = model_(inputs)
            loss = crit(pred, lab)
            sum_loss += loss.item()
            y_actual_list.append(lab.item())
            y_pred_list.append(pred.squeeze().argmax().item())
    ce_weight_log = 'y ratios:' + ':'.join(ratio_calc(y_actual_list)) + '\n' + \
        'y_pred ratios:' + ':'.join(ratio_calc(y_pred_list)) + '\n'
    print(ce_weight_log)
    ce_fw.write(ce_weight_log)
    ce_fw.flush()
    model_.train()
    return sum_loss / len(val_loader)


# save a checkpoint
def save_checkpoint(model_, epoch_, optm, checkpoint_path):
    save_dict = {
        'epoch': epoch_,
        'model_state_dict': model_.state_dict(),
        'optimizer_state_dict': optm.state_dict(),
    }
    torch.save(save_dict, checkpoint_path)


# training
def train():
    parser = ArgumentParser(description='Model Train')
    parser.add_argument(
        '--train_meta_path',  # train.csv
        type=str,
        help='train meta csv'
    )
    parser.add_argument(
        '--eval_meta_path',  # eval.csv
        type=str,
        help='eval meta csv'
    )
    parser.add_argument(  # training can resume from a checkpoint
        '--c',  # checkpoint path
        default=None,
        type=str,
        help='train from scratch if it is none, or resume training from checkpoint'
    )
    args = parser.parse_args()
    model = Model(cfg)
    # weighted cross-entropy, tuned from the class ratios of the first training run
    weights = [0.8867, 1.1350, 0.9683, 0.9632, 0.97508, 1.1316]
    t_weights = torch.FloatTensor(weights).to(cfg['device'])
    criterion = nn.CrossEntropyLoss(weight=t_weights)
    opt = optim.Adam(model.parameters(), lr=cfg['lr'])
    trainset = MelLoader(args.train_meta_path, cfg)
    train_loader = DataLoader(trainset, batch_size=cfg['batch_size'], shuffle=True, drop_last=True)
    evalset = MelLoader(args.eval_meta_path, cfg)
    start_epoch = 0
    iteration = 0
    if args.c:
        checkpoint = torch.load(args.c)
        model.load_state_dict(checkpoint['model_state_dict'])
        opt.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch = checkpoint['epoch']
        iteration = start_epoch * len(train_loader)
        print('Resume training from %s' % args.c)
    else:
        print('Training from scratch!')
    model = model.to(cfg['device'])
    model.train()  # enable dropout, train_flag = True
    # main loop
    for epoch in range(start_epoch, cfg['epoch']):
        print('=' * 33, 'Start Epoch %d, Total: %d iters' % (epoch, len(trainset) / cfg['batch_size']), '=' * 33)
        sum_loss = 0.
        for batch in train_loader:
            inputs, lab = batch
            inputs, lab = inputs.to(cfg['device']), lab.to(cfg['device'])
            opt.zero_grad()
            pred = model(inputs)
            loss = criterion(pred, lab)
            sum_loss += loss.item()
            loss.backward()
            opt.step()
            logger.add_scalar('Loss/Train', loss, iteration)
            if not iteration % cfg['verbose_step']:
                eval_loss = evaluate(model, evalset, criterion)
                logger.add_scalar('Loss/Eval', eval_loss, iteration)
                print('Train Loss: %.4f, Eval Loss: %.4f' % (sum_loss / cfg['verbose_step'], eval_loss))
                sum_loss = 0.
            if not iteration % cfg['save_step']:
                model_path = 'model_%d_%d.pth' % (epoch, iteration)
                save_checkpoint(model, epoch, opt, os.path.join('model_save', model_path))
            iteration += 1
        logger.flush()
        print('Epoch: [%d/%d], step: %d Train Loss: %.4f' % (epoch, cfg['epoch'], iteration, loss.item()))
    logger.close()


if __name__ == '__main__':
    train()