深度学习模型搭建与训练
@Datawhale
Task的内容主要聚焦于CNN(卷积神经网络)的搭建与训练
CNN网络基本架构
- 输入层
- 卷积层
- 激励层
- 池化层
- Flatten操作
- 全连接层
CNN模型训练与测试
# 训练模型
model.fit(X_train, Y_train, epochs = 90, batch_size = 50, validation_data = (X_test, Y_test))
预测测试集
def extract_features(test_dir, file_ext="*.wav"):
feature = []
for fn in tqdm(glob.glob(os.path.join(test_dir, file_ext))[:]): # 遍历数据集的所有文件
X, sample_rate = librosa.load(fn,res_type='kaiser_fast')
mels = np.mean(librosa.feature.melspectrogram(y=X,sr=sample_rate).T,axis=0) # 计算梅尔频谱(mel spectrogram),并把它作为特征
feature.extend([mels])
return feature