Building and Training a Deep Learning Model

01 Recap
(My environment had problems, so this borrows from GitHub.)
The previous Task 2 and Task 3 covered audio data analysis and feature extraction. This task focuses on building and training a CNN model. Since training depends on the feature extraction done earlier, the relevant code is repeated here.

1.1 Imports
In [1]:
# Basic libraries
import pandas as pd
import numpy as np
pd.plotting.register_matplotlib_converters()
import matplotlib.pyplot as plt
%matplotlib inline
import seaborn as sns
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
from sklearn.model_selection import GridSearchCV
from sklearn.preprocessing import MinMaxScaler

# Deep learning framework
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, Flatten, Dense, MaxPool2D, Dropout
from tensorflow.keras.utils import to_categorical
from sklearn.ensemble import RandomForestClassifier
from sklearn.svm import SVC
import tensorflow as tf
import tensorflow.keras

# Audio processing libraries
import os
import librosa
import librosa.display
import glob
/opt/conda/lib/python3.6/site-packages/sklearn/ensemble/weight_boosting.py:29: DeprecationWarning: numpy.core.umath_tests is an internal NumPy module and should not be imported. It will be removed in a future NumPy release.
from numpy.core.umath_tests import inner1d
1.2 Feature Extraction and Dataset Construction
In [2]:
feature = []
label = []

Create the class labels; each class maps to a different integer.

label_dict = {'aloe': 0, 'burger': 1, 'cabbage': 2, 'candied_fruits': 3, 'carrots': 4, 'chips': 5,
              'chocolate': 6, 'drinks': 7, 'fries': 8, 'grapes': 9, 'gummies': 10, 'ice-cream': 11,
              'jelly': 12, 'noodles': 13, 'pickles': 14, 'pizza': 15, 'ribs': 16, 'salmon': 17,
              'soup': 18, 'wings': 19}
label_dict_inv = {v:k for k,v in label_dict.items()}
Define the audio feature-extraction function.

In [3]:
from tqdm import tqdm

def extract_features(parent_dir, sub_dirs, max_file=10, file_ext="*.wav"):
    label, feature = [], []
    for sub_dir in sub_dirs:
        for fn in tqdm(glob.glob(os.path.join(parent_dir, sub_dir, file_ext))[:max_file]):  # iterate over the dataset files
            label_name = fn.split('/')[-2]  # the class name is the parent folder name
            label.extend([label_dict[label_name]])
            X, sample_rate = librosa.load(fn, res_type='kaiser_fast')
            # Mean mel spectrogram over time, used as the feature vector
            mels = np.mean(librosa.feature.melspectrogram(y=X, sr=sample_rate).T, axis=0)
            feature.extend([mels])
    return [feature, label]

In [4]:

Change the directories below to your own paths.

parent_dir = './train_sample/'
save_dir = "./"
folds = sub_dirs = np.array(['aloe', 'burger', 'cabbage', 'candied_fruits',
                             'carrots', 'chips', 'chocolate', 'drinks', 'fries',
                             'grapes', 'gummies', 'ice-cream', 'jelly', 'noodles', 'pickles',
                             'pizza', 'ribs', 'salmon', 'soup', 'wings'])

Extract the features (feature) and class labels (label).

temp = extract_features(parent_dir,sub_dirs,max_file=100)
100%|██████████| 45/45 [00:11<00:00, 5.04it/s]
100%|██████████| 64/64 [00:14<00:00, 4.72it/s]
100%|██████████| 48/48 [00:15<00:00, 2.87it/s]
100%|██████████| 74/74 [00:26<00:00, 1.51it/s]
100%|██████████| 49/49 [00:14<00:00, 3.51it/s]
100%|██████████| 57/57 [00:16<00:00, 3.13it/s]
100%|██████████| 27/27 [00:07<00:00, 3.38it/s]
100%|██████████| 27/27 [00:07<00:00, 3.20it/s]
100%|██████████| 57/57 [00:15<00:00, 3.44it/s]
100%|██████████| 61/61 [00:17<00:00, 3.75it/s]
100%|██████████| 65/65 [00:20<00:00, 3.64it/s]
100%|██████████| 69/69 [00:21<00:00, 3.24it/s]
100%|██████████| 43/43 [00:12<00:00, 3.59it/s]
100%|██████████| 33/33 [00:08<00:00, 3.85it/s]
100%|██████████| 75/75 [00:23<00:00, 3.06it/s]
100%|██████████| 55/55 [00:17<00:00, 2.97it/s]
100%|██████████| 47/47 [00:14<00:00, 3.33it/s]
100%|██████████| 37/37 [00:11<00:00, 2.99it/s]
100%|██████████| 32/32 [00:07<00:00, 3.17it/s]
100%|██████████| 35/35 [00:10<00:00, 2.80it/s]
In [5]:
temp = np.array(temp)
data = temp.transpose()
/opt/conda/lib/python3.6/site-packages/ipykernel_launcher.py:1: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray
  """Entry point for launching an IPython kernel.
In [6]:

Get the features:

X = np.vstack(data[:, 0])

Get the labels:

Y = np.array(data[:, 1])
print('Shape of X:', X.shape)
print('Shape of Y:', Y.shape)
Shape of X: (1000, 128)
Shape of Y: (1000,)
In [7]:

In Keras, to_categorical converts a vector of class integers into a binary (0s and 1s only) matrix representation, i.e. one-hot encoding.

Y = to_categorical(Y)
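The effect of to_categorical can be illustrated with a small NumPy equivalent (a sketch of the same one-hot encoding, not the Keras implementation itself; the helper name one_hot is ours):

```python
import numpy as np

def one_hot(y, num_classes):
    # NumPy equivalent of keras.utils.to_categorical: each integer
    # label becomes a row containing a single 1 at the label's index.
    return np.eye(num_classes)[np.asarray(y, dtype=int)]

print(one_hot([0, 2, 1], 3))
# [[1. 0. 0.]
#  [0. 0. 1.]
#  [0. 1. 0.]]
```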
In [8]:
'''Final data'''
print(X.shape)
print(Y.shape)
(1000, 128)
(1000, 20)
In [9]:
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, random_state=1, stratify=Y)
print('Training set size:', len(X_train))
print('Test set size:', len(X_test))
Training set size: 750
Test set size: 250
In [10]:
X_train = X_train.reshape(-1, 16, 8, 1)
X_test = X_test.reshape(-1, 16, 8, 1)
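The reshape works because each 128-dimensional mel feature vector factors exactly into a 16×8 single-channel "image" that Conv2D can consume; a quick sanity check (the zeros array is a stand-in for the real data):

```python
import numpy as np

# Each sample is a 128-dimensional mel feature vector; Conv2D expects
# a 4-D tensor (batch, height, width, channels), and 16 * 8 * 1 == 128,
# so the reshape loses no data.
features = np.zeros((1000, 128))        # stand-in for the real X_train
x_img = features.reshape(-1, 16, 8, 1)
print(x_img.shape)  # (1000, 16, 8, 1)
```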
02 Building the Model
2.1 The Deep Learning Framework
Keras is a high-level neural network API written in Python that can run on top of TensorFlow, CNTK, or Theano. Keras has since been merged into TensorFlow and can be called through it.

2.1.1 Assembling the Network
The core data structure of Keras is the model, a way of organizing layers. The simplest kind is the Sequential model, a linear stack of layers. For more complex architectures you should use the Keras functional API, which allows building arbitrary graphs of layers.

A Sequential model can be created directly as follows:

from tensorflow.keras.models import Sequential

model = Sequential()

In [11]:
model = Sequential()
2.1.2 Building the CNN
In [12]:

The input shape:

input_dim = (16, 8, 1)
2.1.3 CNN Basics
Among the recommended materials, we suggest watching Hung-yi Lee's (李宏毅) lectures on CNNs; his slides are also attached here.

The basic architecture of a CNN:

(architecture diagram omitted)

A convolutional neural network (CNN) generally contains these layers:

1) Input layer: receives the data.

2) Convolutional layer: extracts and maps features with convolution kernels (can be used repeatedly).

3) Activation layer: convolution is itself a linear operation, so a non-linear mapping (an activation function) must be added.

4) Pooling layer: downsamples the feature maps, sparsifying them and reducing the amount of computation (can be used repeatedly).

5) Flatten: unrolls the 2-D feature maps into a 1-D vector so they can feed into the next (dense) layers.

6) Fully connected layer: usually placed at the end of the CNN to re-fit the features and reduce information loss (a DNN).

In Keras you can simply use .add() to stack the layers of the network you want to build, like stacking building blocks:

In [13]:
model.add(Conv2D(64, (3, 3), padding="same", activation="tanh", input_shape=input_dim))  # convolutional layer
model.add(MaxPool2D(pool_size=(2, 2)))  # max pooling
model.add(Conv2D(128, (3, 3), padding="same", activation="tanh"))  # convolutional layer
model.add(MaxPool2D(pool_size=(2, 2)))  # max pooling
model.add(Dropout(0.1))
model.add(Flatten())  # flatten
model.add(Dense(1024, activation="tanh"))
model.add(Dense(20, activation="softmax"))  # output layer: 20 units give the probabilities of the 20 classes
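For comparison, the same stack can also be written with the Keras functional API mentioned in 2.1.1 (a sketch; the variable name functional_model is illustrative):

```python
from tensorflow.keras import Input, Model
from tensorflow.keras.layers import Conv2D, MaxPool2D, Dropout, Flatten, Dense

# Same architecture as the .add() version, expressed as a layer graph:
inputs = Input(shape=(16, 8, 1))
x = Conv2D(64, (3, 3), padding="same", activation="tanh")(inputs)
x = MaxPool2D(pool_size=(2, 2))(x)
x = Conv2D(128, (3, 3), padding="same", activation="tanh")(x)
x = MaxPool2D(pool_size=(2, 2))(x)
x = Dropout(0.1)(x)
x = Flatten()(x)
x = Dense(1024, activation="tanh")(x)
outputs = Dense(20, activation="softmax")(x)
functional_model = Model(inputs, outputs)
```

The functional style pays off once the graph is no longer a straight line, e.g. with multiple inputs or skip connections.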
If needed, you can further configure your optimizer via .compile(). A core principle of Keras is to keep things reasonably simple, while still allowing the user full control when required (the ultimate control being the easy extensibility of the source code).

In [14]:

Compile the model, setting the loss function, optimization method, and evaluation metric.

model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
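For example, instead of the string 'adam' you can pass an optimizer instance to control hyperparameters such as the learning rate (a sketch on a toy model; 1e-3 is Adam's default value, shown only for illustration):

```python
from tensorflow.keras import Input
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.optimizers import Adam

# Toy model just to demonstrate compile(); the real model above is the CNN.
m = Sequential([Input(shape=(128,)), Dense(20, activation="softmax")])
m.compile(optimizer=Adam(learning_rate=1e-3),  # explicit optimizer instance
          loss="categorical_crossentropy",
          metrics=["accuracy"])
```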
03 CNN Training and Testing
3.1 Training the Model
Train in batches on the model built above:

In [15]:

Train the model.

model.fit(X_train, Y_train, epochs = 90, batch_size = 50, validation_data = (X_test, Y_test))
Epoch 1/90
15/15 [==============================] - 4s 162ms/step - loss: 2.9158 - accuracy: 0.1150 - val_loss: 2.7704 - val_accuracy: 0.1360
Epoch 2/90
15/15 [==============================] - 1s 79ms/step - loss: 2.5593 - accuracy: 0.2049 - val_loss: 2.5741 - val_accuracy: 0.2080
Epoch 3/90
15/15 [==============================] - 1s 81ms/step - loss: 2.2751 - accuracy: 0.3186 - val_loss: 2.4504 - val_accuracy: 0.2520
Epoch 4/90
15/15 [==============================] - 1s 73ms/step - loss: 2.1422 - accuracy: 0.3405 - val_loss: 2.3872 - val_accuracy: 0.2680
Epoch 5/90
15/15 [==============================] - 1s 76ms/step - loss: 1.9961 - accuracy: 0.3965 - val_loss: 2.3609 - val_accuracy: 0.2680
...
Epoch 86/90
15/15 [==============================] - 1s 86ms/step - loss: 0.0261 - accuracy: 0.9884 - val_loss: 5.7649 - val_accuracy: 0.4280
Epoch 87/90
15/15 [==============================] - 1s 84ms/step - loss: 0.0240 - accuracy: 0.9927 - val_loss: 5.8440 - val_accuracy: 0.4320
Epoch 88/90
15/15 [==============================] - 1s 79ms/step - loss: 0.0196 - accuracy: 0.9913 - val_loss: 5.9228 - val_accuracy: 0.4280
Epoch 89/90
15/15 [==============================] - 1s 74ms/step - loss: 0.0182 - accuracy: 0.9971 - val_loss: 5.9385 - val_accuracy: 0.4200
Epoch 90/90
15/15 [==============================] - 1s 80ms/step - loss: 0.0180 - accuracy: 0.9986 - val_loss: 5.9088 - val_accuracy: 0.4320
Out[15]:
<tensorflow.python.keras.callbacks.History at 0x7fca7ad7a6d8>
View the network's summary statistics.

In [16]:
model.summary()
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
conv2d (Conv2D)              (None, 16, 8, 64)         640
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 8, 4, 64)          0
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 8, 4, 128)         73856
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 4, 2, 128)         0
_________________________________________________________________
dropout (Dropout)            (None, 4, 2, 128)         0
_________________________________________________________________
flatten (Flatten)            (None, 1024)              0
_________________________________________________________________
dense (Dense)                (None, 1024)              1049600
_________________________________________________________________
dense_1 (Dense)              (None, 20)                20500
=================================================================
Total params: 1,144,596
Trainable params: 1,144,596
Non-trainable params: 0
_________________________________________________________________
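The parameter counts above can be checked by hand (a worked arithmetic check, using the shapes from the summary):

```python
# Conv2D: kernel_h * kernel_w * in_channels * filters + filters (biases)
conv1 = 3 * 3 * 1 * 64 + 64           # -> 640
conv2 = 3 * 3 * 64 * 128 + 128        # -> 73856
# Dense: inputs * units + units (biases); Flatten yields 4 * 2 * 128 = 1024 inputs
dense1 = 4 * 2 * 128 * 1024 + 1024    # -> 1049600
dense2 = 1024 * 20 + 20               # -> 20500
print(conv1 + conv2 + dense1 + dense2)  # -> 1144596
```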


3.2 Predicting on the Test Set
Generate predictions for new data.

In [19]:
def extract_features(test_dir, file_ext="*.wav"):
    feature = []
    for fn in tqdm(glob.glob(os.path.join(test_dir, file_ext))):  # iterate over all test files
        X, sample_rate = librosa.load(fn, res_type='kaiser_fast')
        # Mean mel spectrogram over time, used as the feature vector
        mels = np.mean(librosa.feature.melspectrogram(y=X, sr=sample_rate).T, axis=0)
        feature.extend([mels])
    return feature
Save the prediction results.

In [20]:
X_test = extract_features('./test_a/')
100%|██████████| 2000/2000 [10:13<00:00, 3.56it/s]
In [21]:
X_test = np.vstack(X_test)
predictions = model.predict(X_test.reshape(-1, 16, 8, 1))
In [22]:
preds = np.argmax(predictions, axis=1)
preds = [label_dict_inv[x] for x in preds]

path = glob.glob('./test_a/*.wav')
result = pd.DataFrame({'name': path, 'label': preds})

result['name'] = result['name'].apply(lambda x: x.split('/')[-1])
result.to_csv('submit.csv', index=None)
In [23]:
!ls ./test_a/*.wav | wc -l
2000
In [24]:
!wc -l submit.csv
2001 submit.csv
