The dataset we use is train.csv; let's open it and take a look.
The id column is a sample number and species is the class; the remaining columns such as margin1 are features describing one aspect of the leaf. There are three broad feature types (margin, shape, texture), each numbered 1 through 64 (e.g. margin1-margin64), and every feature is recorded as a floating-point value. We use these features to predict species.
1 Importing Libraries
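The original import cell is not shown here; a minimal set consistent with the steps below might look like this (Keras is assumed as the deep-learning framework, since the text only mentions Conv1D/LSTM layers and TensorBoard):

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from keras.models import Sequential
from keras.layers import Conv1D, GlobalMaxPooling1D, Dense
from keras.callbacks import TensorBoard
```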
2 Preparing the Data
2.1 Reading the Data
First, we read the data:
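A one-line sketch with pandas (the variable name train is an assumption carried through the rest of this walkthrough):

```python
train = pd.read_csv('train.csv')
train.head()  # peek at the first few rows
```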
2.2 Basic Information About the Data
After reading the data, let's look at its basic characteristics:
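Assuming the train DataFrame from above:

```python
train.info()  # prints row/column counts, dtypes, and memory usage
```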
There are 990 rows and 194 columns in total. The dtypes are: 192 float64 (floating-point) columns, 1 int64 (integer) column (the id column), and 1 object (string) column (the species column); memory usage is roughly 1.5 MB.
2.3 Extracting the Training Inputs and Outputs
We have 99 leaf species in total, and we use pd.factorize to give each species a representative number. Its output looks like the following; because factorize assigns the codes in a fixed order (the order in which each species first appears), the result is the same on every run.
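Assuming the train DataFrame from above, the call producing the output below is simply:

```python
pd.factorize(train['species'])  # returns a (codes, uniques) tuple
```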
(array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
17, 18, 19, 20, 4, 21, 22, 23, 24, 25, 26, 25, 15, 27, 28, 3, 29,
30, 31, 32, 33, 34, 35, 4, 2, 36, 14, 37, 38, 39, 40, 6, 41, 42,
43, 44, 45, 20, 27, 14, 32, 45, 34, 46, 29, 47, 48, 49, 50, 34, 51,
17, 4, 27, 52, 53, 35, 54, 53, 40, 13, 6, 55, 56, 39, 43, 57, 58,
59, 60, 34, 61, 62, 63, 62, 64, 62, 41, 65, 34, 49, 66, 11, 51, 6,
67, 68, 69, 63, 4, 38, 23, 68, 23, 70, 36, 17, 56, 56, 3, 27, 71,
46, 72, 73, 74, 52, 33, 75, 70, 76, 47, 77, 62, 78, 79, 67, 64, 77,
32, 56, 18, 45, 12, 61, 34, 59, 41, 5, 35, 15, 20, 13, 54, 80, 81,
22, 62, 16, 82, 37, 2, 4, 75, 69, 65, 11, 72, 31, 30, 83, 19, 38,
66, 18, 84, 75, 80, 82, 81, 57, 30, 46, 0, 20, 45, 24, 85, 17, 59,
65, 4, 71, 54, 82, 24, 38, 86, 15, 10, 5, 87, 56, 11, 88, 89, 90,
59, 36, 90, 54, 91, 7, 86, 52, 33, 57, 71, 58, 41, 74, 19, 35, 21,
86, 92, 21, 91, 75, 39, 35, 7, 48, 56, 78, 7, 84, 56, 81, 86, 52,
69, 32, 58, 24, 48, 77, 5, 23, 51, 87, 67, 54, 48, 28, 7, 23, 93,
35, 32, 77, 27, 1, 12, 80, 94, 15, 50, 78, 11, 20, 89, 91, 88, 26,
12, 8, 23, 19, 43, 65, 3, 38, 87, 51, 73, 23, 67, 73, 91, 63, 69,
30, 56, 92, 49, 95, 36, 11, 72, 89, 20, 76, 30, 8, 62, 80, 70, 72,
27, 88, 68, 41, 40, 63, 24, 26, 90, 15, 56, 59, 16, 18, 96, 73, 90,
83, 78, 51, 10, 27, 22, 42, 25, 25, 49, 64, 84, 32, 33, 28, 60, 39,
19, 78, 5, 91, 89, 34, 52, 70, 58, 28, 33, 52, 63, 82, 32, 59, 66,
24, 93, 7, 13, 6, 59, 67, 61, 69, 49, 63, 66, 16, 70, 37, 87, 77,
22, 77, 9, 87, 33, 15, 2, 68, 74, 90, 14, 88, 69, 73, 94, 67, 58,
8, 51, 86, 68, 3, 58, 58, 44, 19, 12, 81, 60, 26, 55, 97, 64, 96,
83, 11, 16, 65, 98, 73, 31, 50, 96, 44, 82, 48, 31, 4, 95, 78, 85,
47, 79, 29, 9, 82, 74, 29, 84, 69, 58, 98, 84, 86, 10, 31, 74, 27,
95, 12, 79, 40, 95, 88, 22, 26, 21, 19, 98, 8, 26, 11, 5, 5, 31,
95, 24, 35, 94, 37, 11, 13, 66, 1, 41, 39, 76, 89, 35, 93, 31, 81,
15, 85, 22, 29, 89, 20, 44, 83, 21, 47, 28, 94, 27, 53, 66, 98, 26,
2, 54, 22, 55, 51, 72, 18, 71, 57, 59, 80, 45, 76, 82, 65, 45, 50,
89, 76, 6, 52, 90, 33, 73, 83, 68, 14, 63, 31, 9, 6, 40, 85, 97,
36, 58, 71, 94, 71, 50, 42, 12, 40, 12, 76, 41, 43, 4, 98, 30, 83,
50, 44, 79, 5, 25, 92, 55, 27, 97, 37, 31, 69, 18, 16, 72, 55, 43,
8, 64, 46, 46, 91, 64, 81, 0, 29, 59, 61, 25, 38, 32, 57, 38, 75,
97, 20, 46, 65, 43, 93, 29, 83, 88, 37, 19, 75, 50, 1, 89, 73, 3,
36, 6, 57, 49, 16, 30, 13, 57, 40, 48, 17, 4, 62, 61, 1, 74, 39,
19, 21, 73, 25, 9, 7, 80, 84, 55, 69, 43, 36, 70, 39, 10, 19, 85,
66, 8, 62, 78, 0, 41, 21, 72, 22, 63, 49, 24, 87, 16, 20, 74, 62,
30, 1, 54, 9, 14, 67, 71, 88, 23, 77, 76, 23, 67, 17, 96, 29, 76,
5, 76, 87, 47, 3, 65, 48, 45, 74, 35, 9, 1, 9, 81, 2, 14, 17,
93, 17, 38, 12, 54, 32, 6, 21, 11, 53, 97, 46, 82, 33, 53, 79, 30,
39, 92, 96, 46, 86, 0, 7, 2, 59, 28, 13, 94, 10, 35, 53, 14, 77,
75, 67, 93, 79, 20, 13, 0, 98, 92, 79, 44, 82, 36, 37, 38, 2, 42,
84, 48, 34, 98, 39, 95, 41, 41, 66, 85, 80, 70, 52, 10, 98, 96, 15,
8, 71, 73, 85, 83, 87, 79, 55, 83, 71, 64, 96, 49, 78, 44, 60, 84,
93, 12, 55, 68, 46, 94, 31, 83, 68, 79, 97, 79, 67, 47, 88, 97, 84,
86, 26, 14, 30, 54, 28, 10, 53, 98, 0, 16, 93, 82, 18, 34, 45, 89,
80, 64, 74, 43, 55, 15, 65, 46, 22, 93, 18, 60, 3, 21, 18, 94, 51,
13, 60, 40, 75, 51, 0, 43, 92, 42, 61, 25, 44, 1, 97, 16, 3, 92,
97, 43, 50, 72, 81, 37, 61, 47, 22, 61, 52, 29, 97, 28, 75, 6, 49,
47, 80, 2, 39, 17, 74, 81, 40, 90, 70, 66, 91, 80, 94, 71, 33, 95,
14, 52, 28, 87, 77, 98, 21, 37, 92, 49, 92, 26, 1, 10, 42, 69, 42,
57, 89, 60, 91, 8, 34, 33, 7, 68, 2, 61, 84, 85, 96, 10, 60, 54,
94, 40, 0, 66, 36, 56, 95, 28, 26, 47, 24, 42, 53, 32, 9, 87, 60,
53, 91, 68, 65, 44, 48, 77, 23, 25, 75, 25, 42, 88, 50, 61, 64, 0,
88, 62, 8, 76, 96, 93, 24, 81, 90, 44, 1, 72, 58, 78, 38, 48, 70,
37, 36, 63, 91, 96, 51, 57, 63, 85, 9, 3, 78, 42, 90, 45, 86, 17,
92, 64, 50, 72, 85, 95, 29, 70, 18, 86, 90, 45, 60, 95, 53, 47, 5,
57, 55, 7, 13], dtype=int64), Index(['Acer_Opalus', 'Pterocarya_Stenoptera', 'Quercus_Hartwissiana',
'Tilia_Tomentosa', 'Quercus_Variabilis', 'Magnolia_Salicifolia',
'Quercus_Canariensis', 'Quercus_Rubra', 'Quercus_Brantii',
'Salix_Fragilis', 'Zelkova_Serrata', 'Betula_Austrosinensis',
'Quercus_Pontica', 'Quercus_Afares', 'Quercus_Coccifera',
'Fagus_Sylvatica', 'Phildelphus', 'Acer_Palmatum', 'Quercus_Pubescens',
'Populus_Adenopoda', 'Quercus_Trojana', 'Alnus_Sieboldiana',
'Quercus_Ilex', 'Arundinaria_Simonii', 'Acer_Platanoids',
'Quercus_Phillyraeoides', 'Cornus_Chinensis', 'Liriodendron_Tulipifera',
'Cytisus_Battandieri', 'Rhododendron_x_Russellianum', 'Alnus_Rubra',
'Eucalyptus_Glaucescens', 'Cercis_Siliquastrum', 'Cotinus_Coggygria',
'Celtis_Koraiensis', 'Quercus_Crassifolia', 'Quercus_Kewensis',
'Cornus_Controversa', 'Quercus_Pyrenaica', 'Callicarpa_Bodinieri',
'Quercus_Alnifolia', 'Acer_Saccharinum', 'Prunus_X_Shmittii',
'Prunus_Avium', 'Quercus_Greggii', 'Quercus_Suber',
'Quercus_Dolicholepis', 'Ilex_Cornuta', 'Tilia_Oliveri',
'Quercus_Semecarpifolia', 'Quercus_Texana', 'Ginkgo_Biloba',
'Liquidambar_Styraciflua', 'Quercus_Phellos', 'Quercus_Palustris',
'Alnus_Maximowiczii', 'Quercus_Agrifolia', 'Acer_Pictum',
'Acer_Rufinerve', 'Lithocarpus_Cleistocarpus',
'Viburnum_x_Rhytidophylloides', 'Ilex_Aquifolium', 'Acer_Circinatum',
'Quercus_Coccinea', 'Quercus_Cerris', 'Quercus_Chrysolepis',
'Eucalyptus_Neglecta', 'Tilia_Platyphyllos', 'Alnus_Cordata',
'Populus_Nigra', 'Acer_Capillipes', 'Magnolia_Heptapeta', 'Acer_Mono',
'Cornus_Macrophylla', 'Crataegus_Monogyna', 'Quercus_x_Turneri',
'Quercus_Castaneifolia', 'Lithocarpus_Edulis', 'Populus_Grandidentata',
'Acer_Rubrum', 'Quercus_Imbricaria', 'Eucalyptus_Urnigera',
'Quercus_Crassipes', 'Viburnum_Tinus', 'Morus_Nigra',
'Quercus_Vulcanica', 'Alnus_Viridis', 'Betula_Pendula', 'Olea_Europaea',
'Quercus_Ellipsoidalis', 'Quercus_x_Hispanica', 'Quercus_Shumardii',
'Quercus_Rhysophylla', 'Castanea_Sativa', 'Ulmus_Bergmanniana',
'Quercus_Nigra', 'Salix_Intergra', 'Quercus_Infectoria_sub',
'Sorbus_Aria'],
dtype='object'))
The result above is a tuple; we take element 0 of the tuple as our training labels.
Apart from the first column (id) and the second column (species), all remaining columns are our inputs, so we extract every column except the first two:
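A sketch of the extraction; x and labels are the names the next section refers to:

```python
labels = pd.factorize(train['species'])[0]  # element 0 of the tuple: integer codes
x = train.iloc[:, 2:].values                # everything except id and species
```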
2.4 Splitting the Training and Test Sets
Next we split off the training and test data using the train_test_split function from sklearn's model_selection module, with the x and labels defined above as its arguments (see the sketch after the note below).
- I did not pass any other arguments above; by default 0.75 of the data goes to the training set and 0.25 to the test set, with the counts rounded to whole numbers.
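A minimal sketch (train_x is the name the text refers to later; the other three names are assumptions):

```python
from sklearn.model_selection import train_test_split

# No extra arguments: defaults give a 75% / 25% train/test split
train_x, test_x, train_y, test_y = train_test_split(x, labels)
```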
2.5 Standardization
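The code for this step is not shown in the original; a common choice (an assumption here) is sklearn's StandardScaler, fitted on the training set only and then applied to both sets:

```python
from sklearn.preprocessing import StandardScaler

scaler = StandardScaler()
train_x = scaler.fit_transform(train_x)  # learn mean/std from the training set
test_x = scaler.transform(test_x)        # apply the same transform to the test set
```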
2.6 Adding a Dimension
Because our earlier examples all used an Embedding layer, we never had to change the dimensionality of the training data before it went through the network. One-dimensional convolution and LSTM layers, however, both require input of shape (samples, feature, step):
- samples: the number of training samples
- feature: the number of features per sample
- step: how many values are looked at at a time
Let's look at the shape of train_x at this point:
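Assuming train_x from the split above:

```python
train_x.shape  # (742, 192)
```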
742 is the number of rows and 192 the number of columns; we are still missing the step dimension, so we add one final axis.
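One way to add the axis (np.expand_dims is an assumption; reshape would work just as well):

```python
train_x = np.expand_dims(train_x, axis=-1)  # (742, 192) -> (742, 192, 1)
test_x = np.expand_dims(test_x, axis=-1)    # keep the test set consistent
```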
Now let's look again:
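Continuing the sketch:

```python
train_x.shape  # now (742, 192, 1)
```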
3 Building the Network
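The network code is not shown in the original. Below is a minimal 1D-convolutional sketch matching the (192, 1) input shape; every layer size here is an assumption, and an LSTM-based stack would work similarly:

```python
from keras.models import Sequential
from keras.layers import Conv1D, GlobalMaxPooling1D, Dense

model = Sequential([
    Conv1D(32, 7, activation='relu', input_shape=(192, 1)),  # filter count/size assumed
    GlobalMaxPooling1D(),
    Dense(99, activation='softmax'),  # one output per leaf species
])
```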
4 Compiling the Model
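A sketch of the compile step; sparse categorical cross-entropy fits because the labels are the integer codes from pd.factorize, while the optimizer choice is an assumption:

```python
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['acc'])
```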
5 Training the Model
Training will not take long, so we skip model checkpoints; but there are many epochs, so we add TensorBoard to watch how things go. Since the dataset is small, we do not set a batch size (note: in Keras, leaving batch_size unset makes fit fall back to its default of 32, rather than putting all the data in one batch).
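A sketch of the fit call; the log directory and the epoch count are assumptions:

```python
from keras.callbacks import TensorBoard

tensorboard = TensorBoard(log_dir='logs')   # where TensorBoard reads its event files
history = model.fit(train_x, train_y,
                    epochs=500,             # "many epochs"; the exact number is assumed
                    validation_data=(test_x, test_y),
                    callbacks=[tensorboard])
```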
After training, we save the model:
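For example (the file name is an assumption):

```python
model.save('leaf_model.h5')
```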
Before predicting, let's take a look at TensorBoard; we go straight to the acc curves.
Training accuracy is above 90% and test accuracy is around 90%, so the results are decent.
6 Predicting with the Model
This is a Kaggle dataset, and it comes with a companion test.csv that contains all the information except the species names.
We use it to make predictions: first we preprocess the data the same way, then run a multi-class prediction, as sketched below.
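A sketch of the pipeline, repeating the training-time preprocessing (variable names are assumptions; scaler is the one fitted in section 2.5):

```python
test_data = pd.read_csv('test.csv')
pred_x = test_data.iloc[:, 1:].values      # drop the id column, keep the 192 features
pred_x = scaler.transform(pred_x)          # same standardization as the training data
pred_x = np.expand_dims(pred_x, axis=-1)   # same added step axis

pred = model.predict(pred_x)               # (594, 99): one probability per species
pred_codes = pred.argmax(axis=-1)          # the code with the highest probability
```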
Looking at the result, it returns a list of numbers; each number is the representative code assigned by pd.factorize during training, so we can map them back to get the leaf names.
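A sketch of the mapping, using the uniques index that pd.factorize returns as element 1 of its tuple (pred_codes comes from the sketch above):

```python
codes, uniques = pd.factorize(train['species'])
predicted_names = uniques[pred_codes]  # integer codes back to species names
```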
test.csv contains 594 rows of data in total
- the first line of the file is the header row (id and the feature names), not a data row
Let's check whether our predictions indeed number 594:
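Assuming pred_codes from the sketch above:

```python
len(pred_codes)  # 594
```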
It is indeed 594, which shows that our prediction procedure is working as intended.