Foreword
- I joined this Datawhale-organized speech-recognition competition mainly to experience the workflow and to get familiar with the Tianchi competition environment.
- I spent a lot of time today building an mxnet environment on Tianchi, hoping to use the free GPU quota. It raised an AttributeError even though the code is identical to my local copy; the only difference is that Tianchi runs Python 3.6. I'll try again when I have time.
Baseline Source Code
Download and unzip the training and test sets
!wget http://tianchi-competition.oss-cn-hangzhou.aliyuncs.com/531887/train_sample.zip
!unzip -qq train_sample.zip
!\rm train_sample.zip
!wget http://tianchi-competition.oss-cn-hangzhou.aliyuncs.com/531887/test_a.zip
--2021-04-13 16:24:50-- http://tianchi-competition.oss-cn-hangzhou.aliyuncs.com/531887/test_a.zip
Resolving tianchi-competition.oss-cn-hangzhou.aliyuncs.com (tianchi-competition.oss-cn-hangzhou.aliyuncs.com)... 118.31.232.194
Connecting to tianchi-competition.oss-cn-hangzhou.aliyuncs.com (tianchi-competition.oss-cn-hangzhou.aliyuncs.com)|118.31.232.194|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1092637852 (1.0G) [application/zip]
Saving to: ‘test_a.zip’
100%[====================================>] 1,092,637,852 11.6MB/s in 88s
2021-04-13 16:26:19 (11.8 MB/s) - ‘test_a.zip’ saved [1092637852/1092637852]
!unzip -qq test_a.zip
!\rm test_a.zip
Environment Requirements
- TensorFlow 2.0+
- keras
- sklearn
- librosa
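A quick way to confirm the required packages are importable in the current kernel (a minimal sketch; it only checks presence, not versions):

```python
import importlib.util

# Check that each required package is visible to the current interpreter
for pkg in ("tensorflow", "keras", "sklearn", "librosa"):
    found = importlib.util.find_spec(pkg) is not None
    print(pkg, "OK" if found else "MISSING")
```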
# Basic libraries
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
from sklearn.model_selection import GridSearchCV
from sklearn.preprocessing import MinMaxScaler
Loading the deep learning framework
# Libraries needed to build the classification model
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, Flatten, Dense, MaxPool2D, Dropout
from tensorflow.keras.utils import to_categorical
from sklearn.ensemble import RandomForestClassifier
from sklearn.svm import SVC
/opt/conda/lib/python3.6/site-packages/sklearn/ensemble/weight_boosting.py:29: DeprecationWarning: numpy.core.umath_tests is an internal NumPy module and should not be imported. It will be removed in a future NumPy release.
from numpy.core.umath_tests import inner1d
Loading the audio-processing library
- Oddly, this cell raised an error at first; after switching the kernel environment in the top-right corner, it ran fine.
- `conda list` still does not show librosa, which puzzled me at the time.
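A likely explanation (an assumption, based on how pip and conda behave): `pip install --user` places packages into the per-user site-packages, while `conda list` only reports packages installed into the active conda environment, so librosa can be importable yet invisible to conda. A quick check:

```python
import sys
import site

print(sys.executable)              # the interpreter the notebook is actually running
print(site.getusersitepackages())  # where `pip install --user` places packages
```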
!pip install librosa --user
!conda list
Looking in indexes: https://mirrors.aliyun.com/pypi/simple
Requirement already satisfied: librosa in /data/nas/workspace/envs/python3.6/site-packages (0.8.0)
(remaining "Requirement already satisfied" dependency lines omitted)
# packages in environment at /opt/conda:
#
# Name Version Build Channel
keras 2.2.4 pypi_0 pypi
numpy 1.19.4 pypi_0 pypi
python 3.6.5 hc3d631a_2 https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
scikit-learn 0.24.0 pypi_0 pypi
tensorflow 1.14.0 pypi_0 pypi
tensorflow-cpu 2.4.0 pypi_0 pypi
(listing abridged from ~200 packages; librosa does not appear)
# Other libraries
import os
import librosa
import librosa.display
import glob
Feature extraction and dataset construction
feature = []
label = []
# Build the category labels: each class maps to a distinct integer.
label_dict = {'aloe': 0, 'burger': 1, 'cabbage': 2,'candied_fruits':3, 'carrots': 4, 'chips':5,
'chocolate': 6, 'drinks': 7, 'fries': 8, 'grapes': 9, 'gummies': 10, 'ice-cream':11,
'jelly': 12, 'noodles': 13, 'pickles': 14, 'pizza': 15, 'ribs': 16, 'salmon':17,
'soup': 18, 'wings': 19}
label_dict_inv = {v:k for k,v in label_dict.items()}
from tqdm import tqdm
def extract_features(parent_dir, sub_dirs, max_file=10, file_ext="*.wav"):
    label, feature = [], []
    for sub_dir in sub_dirs:
        for fn in tqdm(glob.glob(os.path.join(parent_dir, sub_dir, file_ext))[:max_file]):  # iterate over up to max_file clips per class
            label_name = fn.split('/')[-2]  # the class name is the parent directory name
            label.extend([label_dict[label_name]])
            X, sample_rate = librosa.load(fn, res_type='kaiser_fast')
            mels = np.mean(librosa.feature.melspectrogram(y=X, sr=sample_rate).T, axis=0)  # mel spectrogram averaged over time, used as the feature vector
            feature.extend([mels])
    return [feature, label]
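Each clip is reduced to a fixed-length vector: `librosa.feature.melspectrogram` returns an `(n_mels, n_frames)` array (128 mel bands by default), and averaging over the time axis yields a 128-dim feature regardless of clip length. A shape-only sketch, with random data standing in for a real spectrogram:

```python
import numpy as np

# stand-in for librosa.feature.melspectrogram output: (n_mels, n_frames)
mel = np.random.rand(128, 431)

# transpose to (n_frames, n_mels), then average over the frames (axis 0)
feat = np.mean(mel.T, axis=0)
print(feat.shape)  # (128,)
```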
# Change these paths to match your setup
parent_dir = './train_sample/'
save_dir = "./"
folds = sub_dirs = np.array(['aloe','burger','cabbage','candied_fruits',
'carrots','chips','chocolate','drinks','fries',
'grapes','gummies','ice-cream','jelly','noodles','pickles',
'pizza','ribs','salmon','soup','wings'])
# Extract the features and the class labels
temp = extract_features(parent_dir,sub_dirs,max_file=100)
100%|██████████| 45/45 [00:12<00:00, 5.03it/s]
100%|██████████| 64/64 [00:14<00:00, 5.09it/s]
100%|██████████| 48/48 [00:17<00:00, 2.88it/s]
100%|██████████| 74/74 [00:26<00:00, 1.31it/s]
100%|██████████| 49/49 [00:14<00:00, 3.50it/s]
100%|██████████| 57/57 [00:17<00:00, 3.65it/s]
100%|██████████| 27/27 [00:07<00:00, 3.48it/s]
100%|██████████| 27/27 [00:07<00:00, 3.54it/s]
100%|██████████| 57/57 [00:15<00:00, 3.67it/s]
100%|██████████| 61/61 [00:17<00:00, 4.01it/s]
100%|██████████| 65/65 [00:19<00:00, 3.11it/s]
100%|██████████| 69/69 [00:22<00:00, 3.08it/s]
100%|██████████| 43/43 [00:12<00:00, 3.41it/s]
100%|██████████| 33/33 [00:09<00:00, 3.37it/s]
100%|██████████| 75/75 [00:23<00:00, 3.15it/s]
100%|██████████| 55/55 [00:18<00:00, 2.96it/s]
100%|██████████| 47/47 [00:14<00:00, 3.50it/s]
100%|██████████| 37/37 [00:13<00:00, 2.04it/s]
100%|██████████| 32/32 [00:07<00:00, 3.87it/s]
100%|██████████| 35/35 [00:11<00:00, 2.76it/s]
temp = np.array(temp)
data = temp.transpose()
/opt/conda/lib/python3.6/site-packages/ipykernel_launcher.py:1: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray
"""Entry point for launching an IPython kernel.
# Get the features
X = np.vstack(data[:, 0])
# Get the labels
Y = np.array(data[:, 1])
print('X shape:', X.shape)
print('Y shape:', Y.shape)
X shape: (1000, 128)
Y shape: (1000,)
# In Keras, to_categorical converts a class vector into a binary (0/1) one-hot matrix
Y = to_categorical(Y)
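`to_categorical` one-hot encodes integer labels. A minimal numpy equivalent, just to show what the transformation does:

```python
import numpy as np

def one_hot(y, num_classes):
    """Numpy sketch of keras' to_categorical: integer labels -> one-hot rows."""
    y = np.asarray(y, dtype=int)
    out = np.zeros((len(y), num_classes))
    out[np.arange(len(y)), y] = 1.0
    return out

print(one_hot([0, 2, 1], 3))
# [[1. 0. 0.]
#  [0. 0. 1.]
#  [0. 1. 0.]]
```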
'''Final data'''
print(X.shape)
print(Y.shape)
(1000, 128)
(1000, 20)
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, random_state = 1, stratify=Y)
print('Training set size:', len(X_train))
print('Test set size:', len(X_test))
Training set size: 750
Test set size: 250
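`train_test_split` defaults to a 25% hold-out, giving the 750/250 split above, and `stratify=Y` keeps class proportions equal in both parts. A small illustration with hypothetical two-class labels:

```python
import numpy as np
from sklearn.model_selection import train_test_split

y = np.array([0] * 40 + [1] * 60)           # 40%/60% class mix
tr, te = train_test_split(y, random_state=1, stratify=y)

# default test_size=0.25 -> 75/25 split, class proportions preserved
print(np.bincount(tr))  # [30 45]
print(np.bincount(te))  # [10 15]
```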
X_train = X_train.reshape(-1, 16, 8, 1)
X_test = X_test.reshape(-1, 16, 8, 1)
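Conv2D expects image-like input, so each 128-dim feature vector is reshaped into a 16×8 "image" with a single channel (16 × 8 = 128). A shape check:

```python
import numpy as np

X = np.random.rand(750, 128)      # 750 feature vectors
X_img = X.reshape(-1, 16, 8, 1)   # 16 * 8 = 128, one channel
print(X_img.shape)  # (750, 16, 8, 1)
```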
Building the Model
Constructing the CNN
model = Sequential()
# Input shape
input_dim = (16, 8, 1)
model.add(Conv2D(64, (3, 3), padding="same", activation="tanh", input_shape=input_dim))  # convolutional layer
model.add(MaxPool2D(pool_size=(2, 2)))  # max pooling
model.add(Conv2D(128, (3, 3), padding="same", activation="tanh"))  # convolutional layer
model.add(MaxPool2D(pool_size=(2, 2)))  # max pooling
model.add(Dropout(0.1))
model.add(Flatten())  # flatten
model.add(Dense(1024, activation="tanh"))
model.add(Dense(20, activation="softmax"))  # output layer: probabilities over the 20 classes
# Compile the model: loss function, optimizer, and evaluation metric
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
model.summary()
Model: "sequential"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
conv2d (Conv2D) (None, 16, 8, 64) 640
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 8, 4, 64) 0
_________________________________________________________________
conv2d_1 (Conv2D) (None, 8, 4, 128) 73856
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 4, 2, 128) 0
_________________________________________________________________
dropout (Dropout) (None, 4, 2, 128) 0
_________________________________________________________________
flatten (Flatten) (None, 1024) 0
_________________________________________________________________
dense (Dense) (None, 1024) 1049600
_________________________________________________________________
dense_1 (Dense) (None, 20) 20500
=================================================================
Total params: 1,144,596
Trainable params: 1,144,596
Non-trainable params: 0
_________________________________________________________________
# Train the model
model.fit(X_train, Y_train, epochs = 20, batch_size = 15, validation_data = (X_test, Y_test))
Epoch 1/20
50/50 [==============================] - 4s 56ms/step - loss: 2.9535 - accuracy: 0.1052 - val_loss: 2.6772 - val_accuracy: 0.1960
Epoch 2/20
50/50 [==============================] - 2s 37ms/step - loss: 2.4855 - accuracy: 0.2418 - val_loss: 2.5755 - val_accuracy: 0.2080
Epoch 3/20
50/50 [==============================] - 2s 37ms/step - loss: 2.2325 - accuracy: 0.3134 - val_loss: 2.4603 - val_accuracy: 0.2520
Epoch 4/20
50/50 [==============================] - 2s 39ms/step - loss: 2.0355 - accuracy: 0.3996 - val_loss: 2.4024 - val_accuracy: 0.2760
Epoch 5/20
50/50 [==============================] - 2s 38ms/step - loss: 1.8670 - accuracy: 0.4200 - val_loss: 2.4080 - val_accuracy: 0.3120
Epoch 6/20
50/50 [==============================] - 2s 37ms/step - loss: 1.6604 - accuracy: 0.4909 - val_loss: 2.4047 - val_accuracy: 0.3280
Epoch 7/20
50/50 [==============================] - 2s 37ms/step - loss: 1.5919 - accuracy: 0.5237 - val_loss: 2.5766 - val_accuracy: 0.3120
Epoch 8/20
50/50 [==============================] - 2s 38ms/step - loss: 1.3910 - accuracy: 0.5578 - val_loss: 2.6057 - val_accuracy: 0.3200
Epoch 9/20
50/50 [==============================] - 2s 37ms/step - loss: 1.2842 - accuracy: 0.6188 - val_loss: 2.6491 - val_accuracy: 0.3160
Epoch 10/20
50/50 [==============================] - 2s 37ms/step - loss: 1.0891 - accuracy: 0.6734 - val_loss: 2.9650 - val_accuracy: 0.3000
Epoch 11/20
50/50 [==============================] - 2s 38ms/step - loss: 1.0029 - accuracy: 0.6969 - val_loss: 2.9276 - val_accuracy: 0.3400
Epoch 12/20
50/50 [==============================] - 2s 37ms/step - loss: 0.8177 - accuracy: 0.7670 - val_loss: 3.0201 - val_accuracy: 0.3680
Epoch 13/20
50/50 [==============================] - 2s 38ms/step - loss: 0.7925 - accuracy: 0.7684 - val_loss: 3.2365 - val_accuracy: 0.3640
Epoch 14/20
50/50 [==============================] - 2s 39ms/step - loss: 0.7578 - accuracy: 0.7711 - val_loss: 3.6040 - val_accuracy: 0.3520
Epoch 15/20
50/50 [==============================] - 2s 38ms/step - loss: 0.6582 - accuracy: 0.8034 - val_loss: 3.4311 - val_accuracy: 0.3800
Epoch 16/20
50/50 [==============================] - 2s 45ms/step - loss: 0.6125 - accuracy: 0.8210 - val_loss: 3.4721 - val_accuracy: 0.3520
Epoch 17/20
50/50 [==============================] - 2s 38ms/step - loss: 0.5335 - accuracy: 0.8556 - val_loss: 3.8178 - val_accuracy: 0.3760
Epoch 18/20
50/50 [==============================] - 2s 37ms/step - loss: 0.4607 - accuracy: 0.8764 - val_loss: 3.7193 - val_accuracy: 0.3480
Epoch 19/20
50/50 [==============================] - 2s 39ms/step - loss: 0.4444 - accuracy: 0.8820 - val_loss: 3.8073 - val_accuracy: 0.3800
Epoch 20/20
50/50 [==============================] - 2s 37ms/step - loss: 0.3612 - accuracy: 0.9125 - val_loss: 3.8732 - val_accuracy: 0.3720
<tensorflow.python.keras.callbacks.History at 0x7ff66c0f7320>
Predicting on the Test Set
def extract_features(test_dir, file_ext="*.wav"):
    feature = []
    for fn in tqdm(glob.glob(os.path.join(test_dir, file_ext))):  # iterate over all test files
        X, sample_rate = librosa.load(fn, res_type='kaiser_fast')
        mels = np.mean(librosa.feature.melspectrogram(y=X, sr=sample_rate).T, axis=0)  # mel spectrogram averaged over time
        feature.extend([mels])
    return feature
X_test = extract_features('./test_a/')  # note: this overwrites the validation-split X_test from earlier
100%|██████████| 2000/2000 [10:28<00:00, 3.34it/s]
X_test = np.vstack(X_test)
predictions = model.predict(X_test.reshape(-1, 16, 8, 1))
preds = np.argmax(predictions, axis = 1)
preds = [label_dict_inv[x] for x in preds]
path = glob.glob('./test_a/*.wav')
result = pd.DataFrame({'name':path, 'label': preds})
result['name'] = result['name'].apply(lambda x: x.split('/')[-1])
result.to_csv('submit.csv',index=None)
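The decoding step above can be sketched in isolation: `argmax` over the class axis picks the most probable class index for each clip, and the inverse dictionary maps that index back to a name (hypothetical 3-class probabilities here, with a toy stand-in for `label_dict_inv`):

```python
import numpy as np

# hypothetical per-clip class probabilities (2 clips, 3 classes)
probs = np.array([[0.1, 0.7, 0.2],
                  [0.6, 0.3, 0.1]])
idx = np.argmax(probs, axis=1)                # most probable class per clip
inv = {0: 'aloe', 1: 'burger', 2: 'cabbage'}  # toy stand-in for label_dict_inv
labels = [inv[i] for i in idx]
print(labels)  # ['burger', 'aloe']
```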
!ls ./test_a/*.wav | wc -l
2000
!wc -l submit.csv