Kaggle入门笔记——基于CNN识别Mnist手写数字数据集
数据准备与实验流程
以Kaggle平台提供的train.csv和test.csv为数据源,链接如下:
链接:https://pan.baidu.com/s/1y6QVAXoBuuQ1IDMMp1eXYQ
提取码:hvlx
实验大致流程如图所示
Import and Define
from keras import models
from keras import layers
from keras import optimizers
from keras.utils import to_categorical
from sklearn.model_selection import train_test_split
import pandas as pd
from plot_result import plot_results
import numpy as np
DataPath = 'F:/kaggle/Mnist/Dataset/train.csv' #path to the Kaggle train.csv (labels + pixels)
TestPath = 'F:/kaggle/Mnist/Dataset/test.csv' #path to the Kaggle test.csv (pixels only)
model_path = 'F:/kaggle/Mnist/Code/.idea/cnn_model.h5' #where the trained Keras model is saved
Batchsize = 64  # mini-batch size used for training
Epochs = 30  # number of training epochs
更改数据格式并划分数据集
def train_test_divided(filepath, test_size=0.25, random_state=None):
    """Load the Kaggle MNIST train.csv and split it into train/validation sets.

    The CSV has 785 columns: column 0 is the digit label (0-9), the remaining
    784 columns are the flattened 28x28 grayscale pixel values.

    Args:
        filepath: path to train.csv.
        test_size: fraction held out for validation (default 0.25, i.e. 3:1).
        random_state: optional seed forwarded to train_test_split so the
            split is reproducible; None keeps the original random behavior.

    Returns:
        (train_data, train_label, test_data, test_label) where the data
        arrays are float32 of shape (n, 28, 28, 1) scaled to [0, 1] and the
        labels are one-hot encoded arrays of shape (n, 10).
    """
    load_data = pd.read_csv(filepath)
    # Column 0 is the label; columns 1..784 are pixel intensities.
    data = np.array(load_data.iloc[:, 1:])
    label = np.array(load_data.iloc[:, 0])
    # Reshape flat pixel rows to the CNN input shape (28, 28, 1).
    data = np.reshape(data, (-1, 28, 28, 1))
    train_data, test_data, train_label, test_label = train_test_split(
        data, label, test_size=test_size, random_state=random_state)
    # Normalize pixel values to [0, 1].
    train_data = train_data.astype('float32') / 255.
    test_data = test_data.astype('float32') / 255.
    # One-hot encode the labels; to_categorical already returns an ndarray,
    # so the original np.array() wrapper was redundant and is dropped.
    train_label = to_categorical(train_label)
    test_label = to_categorical(test_label)
    return train_data, train_label, test_data, test_label
构建CNN网络模型
def CNN_Model():
    """Build the MNIST digit classifier: three conv stages feeding a
    dense head with a 10-way softmax output.

    Expects float inputs of shape (28, 28, 1); returns the uncompiled
    Keras Sequential model.
    """
    # Passing the layer list to Sequential builds the exact same stack
    # as repeated .add() calls.
    net = models.Sequential([
        layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)),
        layers.MaxPooling2D(2, 2),
        layers.Conv2D(64, (3, 3), activation='relu'),
        layers.MaxPooling2D(2, 2),
        layers.Conv2D(128, (3, 3), activation='relu'),
        layers.Dropout(0.1),
        layers.Flatten(),
        layers.Dense(256, activation='relu'),
        layers.Dropout(0.1),
        layers.Dense(10, activation='softmax'),
    ])
    return net
模型评价与结果
迭代30次,Batch_size设置为64,loss和acc如图所示
经过简单调参后,线下train_loss为0.0122,train_acc为0.9960,val_loss为0.0338,val_acc为0.9903,同时在kaggle平台在线分数可达0.99028
由于之前有接触过Mnist数据集分类,故在此做一个简短的总结,后期可能通过采用数据增强、在确保模型不会过拟合的情况下增大训练集数量或者更改网络结构等方法,均有可能使得分数更高。